The OP updated the CodePen link they shared with a working solution.
I’ve pasted the JavaScript code from that URL below (in case it becomes unavailable for any reason); expand the details
to see the code.
// Class names, looked up by the class index of each detection.
const labels = [
  // list of class names (removed for brevity)
];

// Number of per-class score channels read from the model output.
const numClasses = labels.length;

// Factor applied to box coordinates before they are drawn.
const IMAGE_SIZE = 640;

// Tell the TFLite runtime where to fetch its WASM binaries from.
const TFLITE_WASM_PATH =
  "https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.10/wasm/";
tflite.setWasmPath(TFLITE_WASM_PATH);
/**
 * Draws detection boxes and their "<label> - <score>%" captions onto the
 * "canvas-overlay" canvas, clearing anything previously drawn there.
 *
 * @param {ArrayLike<number>} boxes_data - Flat list with four entries per
 *   detection, ordered [y1, x1, y2, x2]. Each value is multiplied by
 *   IMAGE_SIZE before drawing (presumably normalized coordinates — confirm
 *   against the model output).
 * @param {ArrayLike<number>} scores_data - One score per detection; rendered
 *   as a percentage with one decimal.
 * @param {ArrayLike<number>} classes_data - One index into `labels` per
 *   detection.
 */
function drawBoxes(boxes_data, scores_data, classes_data) {
  const canvas1 = document.getElementById("canvas-overlay");
  const ctx = canvas1.getContext("2d");
  ctx.clearRect(0, 0, canvas1.width, canvas1.height);

  // Font size scales with the larger canvas dimension but never drops below 14px.
  const fontSize = Math.max(
    Math.round(Math.max(ctx.canvas.width, ctx.canvas.height) / 40),
    14
  );
  ctx.font = `${fontSize}px Arial`;
  ctx.textBaseline = "top";

  const color = "#f00";
  const borderWidth = 2;
  ctx.lineWidth = borderWidth;

  for (let i = 0; i < scores_data.length; ++i) {
    const classIndex = classes_data[i];
    const klass = labels[classIndex];
    // Fall back to the numeric index so a missing label never renders as "undefined".
    const name = klass === undefined ? `class ${classIndex}` : klass;
    const score = (scores_data[i] * 100).toFixed(1);
    const caption = `${name} - ${score}%`;

    const offset = i * 4;
    const y1 = boxes_data[offset] * IMAGE_SIZE;
    const x1 = boxes_data[offset + 1] * IMAGE_SIZE;
    const y2 = boxes_data[offset + 2] * IMAGE_SIZE;
    const x2 = boxes_data[offset + 3] * IMAGE_SIZE;
    const width = x2 - x1;
    const height = y2 - y1;

    // Box outline.
    ctx.strokeStyle = color;
    ctx.strokeRect(x1, y1, width, height);

    // Label background: filled with the box color so the white caption stays
    // readable on top of the image.
    const textWidth = ctx.measureText(caption).width;
    const textHeight = fontSize;
    const yText = y1 - (textHeight + borderWidth);
    // Clamp to the top edge so a label above the canvas is still visible.
    const labelTop = yText < 0 ? 0 : yText;
    ctx.fillStyle = color;
    ctx.fillRect(
      x1 - 1,
      labelTop,
      textWidth + borderWidth,
      textHeight + borderWidth
    );

    // Caption text.
    ctx.fillStyle = "#ffffff";
    ctx.fillText(caption, x1 - 1, labelTop);
  }
}
/**
 * Loads the YOLOv8 .tflite model, runs it on the "original-image" element and
 * draws the detections that survive non-max suppression via drawBoxes().
 *
 * Failures are reported with console.warn instead of being thrown. Every
 * tensor created during the run is released in a finally block, so memory is
 * reclaimed on both the success and the error path.
 *
 * @returns {Promise<void>} Resolves once drawing (or error logging) is done.
 */
async function start() {
  // Tensors created outside tf.tidy are not reclaimed automatically, so each
  // one is registered here and disposed in the finally block.
  const owned = [];
  const keep = (t) => {
    owned.push(t);
    return t;
  };

  try {
    // Load .tflite model from personal portfolio, mimicking how we import the model
    // from our react-native app's assets folder.
    // According to netron, this model has an input tensor shape of [1, 640, 640, 3],
    // and an output tensor shape of [1, 84, 8400]
    const tfliteModel = await tflite.loadTFLiteModel(
      "https://drewjamesandre.com/.well-known/yolov8s-oiv7_float32.tflite"
    );
    console.log("Model loaded!");

    // Prepare input tensor data from image
    const imageElement = document.getElementById("original-image");
    const tensor = keep(tf.browser.fromPixels(imageElement));

    // Debugging aid: render the exact pixels handed to the model, in case a
    // malformed input image is the cause of bad predictions.
    const canvas = document.getElementById("canvas-overlay");
    tf.browser.draw(tensor, canvas);
    const dataUrl = canvas.toDataURL();
    const img = document.getElementById("model-image");
    img.src = dataUrl;

    // Add the batch dimension and convert the pixel values to [0, 1].
    // tf.tidy releases the intermediate expandDims tensor.
    const input = keep(tf.tidy(() => tf.div(tf.expandDims(tensor), 255)));

    // Run the model prediction, and save the resulting tensor output data
    const outputTensor = keep(tfliteModel.predict(input));
    // transpose result [b, det, n] => [b, n, det]
    const transRes = keep(outputTensor.transpose([0, 2, 1]));

    // All of the box processing is borrowed from https://github.com/Hyuto/yolov8-tfjs/blob/master/src/utils/detect.js
    const boxes = keep(
      tf.tidy(() => {
        const w = transRes.slice([0, 0, 2], [-1, -1, 1]); // get width
        const h = transRes.slice([0, 0, 3], [-1, -1, 1]); // get height
        const x1 = tf.sub(transRes.slice([0, 0, 0], [-1, -1, 1]), tf.div(w, 2)); // x1
        const y1 = tf.sub(transRes.slice([0, 0, 1], [-1, -1, 1]), tf.div(h, 2)); // y1
        return tf
          .concat(
            [
              y1,
              x1,
              tf.add(y1, h), // y2
              tf.add(x1, w) // x2
            ],
            2
          )
          .squeeze();
      })
    ); // process boxes [y1, x1, y2, x2]

    const [scores, classes] = tf.tidy(() => {
      // class scores
      const rawScores = transRes
        .slice([0, 0, 4], [-1, -1, numClasses])
        .squeeze(0); // #6 only squeeze axis 0 to handle only 1 class models
      return [rawScores.max(1), rawScores.argMax(1)];
    }); // get max scores and classes index
    keep(scores);
    keep(classes);

    // NMS to filter boxes: at most 20 boxes, IoU threshold 0.5, score threshold 0.2.
    const nms = keep(
      await tf.image.nonMaxSuppressionAsync(boxes, scores, 20, 0.5, 0.2)
    );

    // Index boxes, scores and classes by the NMS result. The gathered tensors
    // are tracked too so they do not leak after their data has been read.
    const boxes_data = keep(boxes.gather(nms, 0)).dataSync();
    const scores_data = keep(scores.gather(nms, 0)).dataSync();
    const classes_data = keep(classes.gather(nms, 0)).dataSync();

    // Draw boxes to the existing canvas, which has a background image containing the
    // image supplied to the model
    drawBoxes(boxes_data, scores_data, classes_data);

    // Remove loading state
    document.getElementById("loading-text").remove();
  } catch (error) {
    console.warn("Failed to predict:", error.message);
  } finally {
    // Clean up tensors regardless of whether the run succeeded.
    tf.dispose(owned);
  }
}
// Entry point: select the CPU backend, then run the detection demo.
// Any rejection from either step is passed to console.warn.
(async () => {
  try {
    await tf.setBackend("cpu");
    await start();
  } catch (err) {
    console.warn(err);
  }
})();