// Tensorflow
import * as tf from "@tensorflow/tfjs";
import { GraphModel, Tensor, Rank, Tensor2D } from "@tensorflow/tfjs";

// GoMap
import { Detection, ModelKey, ObjectDetectionModel } from "@models/detection.model";

/**
 * Object detector backed by a YOLO-style TensorFlow.js graph model.
 *
 * Pre-processing zero-pads each frame on the bottom/right to a square, resizes it to the model
 * input size and scales pixel values to [0, 1]. Post-processing decodes the raw predictions into
 * corner boxes, applies non-max suppression and maps the surviving boxes back to frame pixels.
 *
 * Tensor memory is managed with `tf.tidy` plus explicit disposal rather than engine scopes, so
 * no scope is ever held open across an `await` or a `yield`.
 */
export class YoloModel implements ObjectDetectionModel {
    /** Model input width in pixels (index 2 of the channels-last input shape). */
    public readonly width: number;
    /** Model input height in pixels (index 1 of the channels-last input shape). */
    public readonly height: number;

    /**
     * @param model Loaded graph model. Its first input is fed `[batch, height, width, 3]` tensors.
     * @param classes Class labels, indexed by the model's class-score channel.
     * @param key Identifier of this model.
     * @throws Error if the model does not expose a usable input shape.
     */
    constructor(
        public readonly model: GraphModel,
        public readonly classes: ReadonlyArray<string>,
        public readonly key: ModelKey
    ) {
        const shape = model.inputs[0]?.shape;
        if (!shape || shape.length < 3) {
            throw new Error("Invalid model input shape");
        }
        // Frames are fed channels-last ([batch, height, width, 3]), so height precedes width.
        const [, height, width] = shape;
        this.width = width;
        this.height = height;
    }

    /**
     * Converts a frame into a model input tensor of shape `[1, height, width, 3]`.
     *
     * The frame is zero-padded on the bottom/right to a square (which keeps the aspect ratio and
     * leaves the top-left origin in place), resized to the model input size and scaled to [0, 1].
     * All intermediates are released by `tf.tidy`.
     */
    private _imgToTensor(bitmap: ImageBitmap): Tensor {
        return tf.tidy(() => {
            const imgTensor = tf.browser.fromPixels(bitmap);
            const [h, w] = imgTensor.shape;
            const maxSize = Math.max(w, h);
            const imgPadded = imgTensor.pad([
                [0, maxSize - h], // padding y [bottom only]
                [0, maxSize - w], // padding x [right only]
                [0, 0],
            ]) as tf.Tensor3D;

            // resizeBilinear expects [newHeight, newWidth].
            return tf.image
                .resizeBilinear(imgPadded, [this.height, this.width])
                .div(255.0) // normalize
                .expandDims(0);
        });
    }

    /**
     * Maps NMS-filtered boxes from model-input pixels back to frame pixels.
     *
     * @param boxData Flattened `[y1, x1, y2, x2]` boxes, 4 values per detection.
     * @param scoreData Best class score per detection.
     * @param classData Best class index per detection.
     * @param bitmap Source frame; its size sets the scale factors and the clamping bounds.
     */
    private *_rawDetectionsToObject(
        boxData: Float32Array,
        scoreData: Float32Array,
        classData: Int32Array,
        bitmap: ImageBitmap
    ): Generator<Detection> {
        // The frame was padded to a maxSide x maxSide square before being resized to the model
        // input; padding only adds pixels on the bottom/right, so scaling alone undoes it.
        // Separate factors keep the mapping correct for non-square model inputs.
        const maxSide = Math.max(bitmap.width, bitmap.height);
        const scaleX = maxSide / this.width;
        const scaleY = maxSide / this.height;
        const clamp = (value: number, upper: number) => Math.min(Math.max(value, 0), upper);

        for (let i = 0; i < scoreData.length; i++) {
            const boxStart = i * 4;

            // Clamp both corners so a box spilling past the frame is cropped, not shifted.
            const x1 = clamp(boxData[boxStart + 1] * scaleX, bitmap.width);
            const y1 = clamp(boxData[boxStart] * scaleY, bitmap.height);
            const x2 = clamp(boxData[boxStart + 3] * scaleX, bitmap.width);
            const y2 = clamp(boxData[boxStart + 2] * scaleY, bitmap.height);

            yield {
                x: Math.round(x1),
                y: Math.round(y1),
                width: Math.round(Math.max(0, x2 - x1)),
                height: Math.round(Math.max(0, y2 - y1)),
                score: scoreData[i],
                classifier: this.classes[classData[i]],
                imgWidth: bitmap.width,
                imgHeight: bitmap.height,
            };
        }
    }

    /**
     * Decodes raw model output into detections in frame pixel coordinates.
     *
     * The slicing below relies on the output being `[1, 4 + numClasses, numAnchors]`: after the
     * transpose, channels 0-3 are box centre x/y and width/height, and the remaining
     * `classes.length` channels are per-class scores.
     *
     * Every tensor created here is disposed before returning, even on error; `output` itself is
     * owned by the caller.
     */
    private async _tensorToDetections(
        output: Tensor<Rank>,
        bitmap: ImageBitmap,
        maxDetections: number,
        iouThreshold: number,
        scoreThreshold: number
    ): Promise<Detection[]> {
        const { boxes, scores, classes } = tf.tidy(() => {
            const preds = output.transpose([0, 2, 1]); // [1, numAnchors, 4 + numClasses]
            const w = preds.slice([0, 0, 2], [-1, -1, 1]); // get width
            const h = preds.slice([0, 0, 3], [-1, -1, 1]); // get height
            const x1 = tf.sub(preds.slice([0, 0, 0], [-1, -1, 1]), tf.div(w, 2));
            const y1 = tf.sub(preds.slice([0, 0, 1], [-1, -1, 1]), tf.div(h, 2));

            // NMS takes [y1, x1, y2, x2] corner boxes. Squeeze only the batch axis so a
            // single-anchor result keeps rank 2.
            const cornerBoxes = tf
                .concat([y1, x1, tf.add(y1, h), tf.add(x1, w)], 2)
                .squeeze([0]) as Tensor2D;

            // #6 only squeeze axis 0 to handle only 1 class models
            const classScores = preds
                .slice([0, 0, 4], [-1, -1, this.classes.length])
                .squeeze([0]);

            return {
                boxes: cornerBoxes,
                scores: classScores.max(1) as tf.Tensor1D,
                classes: classScores.argMax(1) as tf.Tensor1D,
            };
        });

        const owned: Tensor[] = [boxes, scores, classes];
        try {
            const keep = await tf.image.nonMaxSuppressionAsync(
                boxes,
                scores,
                maxDetections,
                iouThreshold,
                scoreThreshold
            );
            owned.push(keep);

            const kept = tf.tidy(() => ({
                boxes: boxes.gather(keep, 0),
                scores: scores.gather(keep, 0),
                classes: classes.gather(keep, 0),
            }));
            owned.push(kept.boxes, kept.scores, kept.classes);

            const [boxData, scoreData, classData] = await Promise.all([
                kept.boxes.data() as Promise<Float32Array>,
                kept.scores.data() as Promise<Float32Array>,
                kept.classes.data() as Promise<Int32Array>,
            ]);

            // Materialise now: the typed arrays are plain JS data, independent of the tensors.
            return Array.from(this._rawDetectionsToObject(boxData, scoreData, classData, bitmap));
        } finally {
            tf.dispose(owned);
        }
    }

    /**
     * Runs the model once on an all-zero input of the model's input size.
     * Presumably this lets backend initialisation happen before the first real frame.
     */
    async warmup(): Promise<void> {
        console.debug("Warming up...");
        // tidy frees the zero input; only the model output survives.
        const output = tf.tidy(
            () => this.model.execute(tf.zeros([1, this.height, this.width, 3])) as Tensor
        );

        try {
            await output.data();
            console.debug("Model has warmed up");
        } finally {
            output.dispose();
        }
    }

    /**
     * Runs pre-processing, the model and post-processing for one frame.
     * All tensors are released before this resolves, whatever the caller does afterwards.
     */
    private async _runInference(
        bitmap: ImageBitmap,
        maxDetections: number,
        iouThreshold: number,
        scoreThreshold: number
    ): Promise<Detection[]> {
        const start = Date.now();
        let output: Tensor | undefined;

        try {
            // tidy frees the input tensor and any intermediates; only the model output survives.
            output = tf.tidy(() => this.model.execute(this._imgToTensor(bitmap)) as Tensor);
            return await this._tensorToDetections(
                output,
                bitmap,
                maxDetections,
                iouThreshold,
                scoreThreshold
            );
        } finally {
            output?.dispose();
            console.debug(`Inference took ${(Date.now() - start) / 1000} seconds`);
        }
    }

    /**
     * Detects objects in a frame.
     *
     * @param bitmap Frame to analyse.
     * @param options Optional NMS settings. Defaults: at most 20 detections, IoU threshold 0.45,
     *   score threshold 0.2.
     * @returns Detections in frame pixel coordinates. All of them are computed, and all tensors
     *   freed, before the first one is yielded.
     */
    async *detect(
        bitmap: ImageBitmap,
        options: { maxDetections?: number; iouThreshold?: number; scoreThreshold?: number } = {}
    ): AsyncGenerator<Detection> {
        const { maxDetections = 20, iouThreshold = 0.45, scoreThreshold = 0.2 } = options;
        yield* await this._runInference(bitmap, maxDetections, iouThreshold, scoreThreshold);
    }

    /** Releases the underlying graph model. */
    public dispose() {
        this.model.dispose();
    }
}
