import createTFLiteModule, {TFLite} from "../tflite/tflite.js";
import createTFLiteSIMDModule from "../tflite/tflite-simd.js";
import {getTFLiteModelFileName, SegmentationConfig} from "./SegmentationHelper";
export type {TFLite};

/**
 * Instantiates the TFLite WebAssembly module that matches the runtime's capabilities.
 *
 * When `isSIMDSupported` is true, the SIMD build (`tflite-simd.wasm`) is loaded.
 * Otherwise the baseline build (`tflite.wasm`) is used. Each `.wasm` binary is
 * resolved relative to `process.env.PUBLIC_URL` via the module's `locateFile` hook.
 *
 * @param isSIMDSupported - whether the runtime supports WebAssembly SIMD.
 * @returns a promise resolving to the instantiated TFLite module.
 */
export async function loadTFLite(isSIMDSupported: boolean) {
    // Only SIMD-capable runtimes may receive the SIMD build; everyone else
    // must get the baseline build.
    if (isSIMDSupported) {
        return createTFLiteSIMDModule({
            locateFile: () => process.env.PUBLIC_URL + "/tflite/tflite-simd.wasm",
        });
    }
    return createTFLiteModule({
        locateFile: () => process.env.PUBLIC_URL + "/tflite/tflite.wasm",
    });
}

/**
 * Downloads the segmentation model selected by `segmentationConfig` and loads it
 * into the given TFLite module.
 *
 * The model bytes are copied into the module's heap (`HEAPU8`) at the offset
 * reported by `_getModelBufferMemoryOffset()`. `_loadModel` is then called with
 * the model's byte length. The `_loadModel` result and the model's input/output
 * dimensions are logged to the console.
 *
 * @param tflite - an instantiated TFLite module, e.g. from `loadTFLite`.
 * @param segmentationConfig - selects the model file via its `model` and `inputResolution`.
 * @throws Error if the model file request completes with a non-2xx HTTP status.
 */
export async function loadTFLiteModel(tflite: TFLite, segmentationConfig: SegmentationConfig): Promise<void> {
    const modelFileName = getTFLiteModelFileName(segmentationConfig.model, segmentationConfig.inputResolution);
    console.log("Loading tflite model:", modelFileName);
    const modelUrl = `${process.env.PUBLIC_URL}/virtual-background/models/${modelFileName}.tflite`;
    const modelResponse = await fetch(modelUrl);
    // fetch() only rejects on network failures. Without this check an HTTP error
    // (e.g. a 404 page) would be copied into the heap and handed to _loadModel
    // as if it were a model.
    if (!modelResponse.ok) {
        throw new Error(
            `Failed to fetch tflite model ${modelUrl}: ${modelResponse.status} ${modelResponse.statusText}`,
        );
    }
    const model = await modelResponse.arrayBuffer();
    const modelBufferOffset = tflite._getModelBufferMemoryOffset();
    tflite.HEAPU8.set(new Uint8Array(model), modelBufferOffset);
    console.log("_loadModel result:", tflite._loadModel(model.byteLength));
    console.log("Input height:", tflite._getInputHeight());
    console.log("Input width:", tflite._getInputWidth());
    console.log("Output height:", tflite._getOutputHeight());
    console.log("Output width:", tflite._getOutputWidth());
}
