YOLOv8 hand detection model: works in PyTorch but fails after conversion to TensorFlow.js

Hello,
I’m looking for help with my machine learning model that detects my hand and one wrist keypoint. After training, the model correctly detects my hand with a bounding box and wrist keypoint in PyTorch. However, after converting the best.pt file to a TensorFlow.js model, the detection fails: it no longer detects my hand or the keypoint.

Model Details

YOLOv8 trained for pose detection
Custom dataset with hand images and wrist keypoint annotations
Input size: 224x224
The model works correctly in the PyTorch environment (see the quick check below)
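
For context, the PyTorch-side check is nothing more elaborate than a quick predict call along these lines (a minimal sketch; file names are placeholders), and both the box and the wrist keypoint land where expected:

from ultralytics import YOLO

# Quick sanity check of the trained .pt weights (paths are placeholders)
model = YOLO("best.pt")
results = model.predict("sample_hand.jpg", imgsz=224, conf=0.25)

r = results[0]
print(r.boxes.xyxy)     # hand bounding box(es), pixel coordinates
print(r.keypoints.xy)   # wrist keypoint coordinates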

Here is how I did my conversion:

import os
from ultralytics import YOLO
import shutil
import tensorflow as tf
from google.colab import files

def find_saved_model(base_path):
    """Find the SavedModel directory in the export path"""
    for root, dirs, files in os.walk(base_path):
        if 'saved_model.pb' in files:
            return root
    return None

def add_signatures(saved_model_dir):
    """Load the SavedModel and add required signatures"""
    print("Adding signatures to SavedModel...")

    # Load the model
    model = tf.saved_model.load(saved_model_dir)

    # Create a wrapper function that matches the model's interface
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 640, 640, 3], dtype=tf.float32, name='images')
    ])
    def serving_fn(images):
        # Call model directly without training parameter
        return model(images)

    # Convert the model
    concrete_func = serving_fn.get_concrete_function()

    # Create a new SavedModel with the signature
    tf.saved_model.save(
        model,
        saved_model_dir,
        signatures={
            'serving_default': concrete_func
        }
    )

    print("Signatures added successfully")
    return saved_model_dir

def convert_to_tfjs(pt_model_path, output_dir):
    """
    Convert a PyTorch YOLO model to TensorFlow.js format

    Args:
        pt_model_path (str): Path to the .pt file
        output_dir (str): Directory to save the converted model
    """
    try:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Load the model
        print(f"Loading YOLO model from {pt_model_path}...")
        model = YOLO(pt_model_path)

        # First export to TensorFlow format
        print("Exporting to TensorFlow format...")

        # Export the model; Ultralytics returns the path to the exported SavedModel
        export_path = model.export(
            format='saved_model',
            imgsz=672,
            half=False,
            simplify=True
        )

        # Locate the SavedModel directory inside the exported path
        saved_model_dir = find_saved_model(export_path if export_path else os.getcwd())
        if not saved_model_dir:
            raise Exception(f"Cannot find SavedModel directory under {export_path}")

        print(f"Found SavedModel at: {saved_model_dir}")

        # Add signatures to the model
        saved_model_dir = add_signatures(saved_model_dir)

        # Convert to TensorFlow.js
        print("Converting to TensorFlow.js format...")
        tfjs_target_dir = os.path.join(output_dir, 'tfjs_model')

        # Ensure clean target directory
        if os.path.exists(tfjs_target_dir):
            shutil.rmtree(tfjs_target_dir)
        os.makedirs(tfjs_target_dir)

        # Try conversion with modified parameters
        conversion_command = (
            f"tensorflowjs_converter "
            f"--input_format=tf_saved_model "
            f"--output_format=tfjs_graph_model "
            f"--saved_model_tags=serve "
            f"--control_flow_v2=True "
            f"'{saved_model_dir}' "
            f"'{tfjs_target_dir}'"
        )

        print(f"Running conversion command: {conversion_command}")
        result = os.system(conversion_command)

        if result != 0:
            raise Exception("TensorFlow.js conversion failed")

        # Verify conversion
        if not os.path.exists(os.path.join(tfjs_target_dir, 'model.json')):
            raise Exception("TensorFlow.js conversion failed - model.json not found")

        print(f"Successfully converted model to TensorFlow.js format")
        print(f"Output saved to: {tfjs_target_dir}")

        # Print model files
        print("\nConverted model files:")
        for file in os.listdir(tfjs_target_dir):
            print(f"- {file}")

        # Create a zip file of the converted model
        shutil.make_archive(tfjs_target_dir, 'zip', tfjs_target_dir)

        # Download the zip file
        files.download(f"{tfjs_target_dir}.zip")

    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        print("\nDebug information:")
        print(f"Current working directory: {os.getcwd()}")
        print(f"PT model exists: {os.path.exists(pt_model_path)}")
        if 'saved_model_dir' in locals():
            print(f"SavedModel directory exists: {os.path.exists(saved_model_dir)}")
            if os.path.exists(saved_model_dir):
                print("SavedModel contents:")
                for root, dirs, files in os.walk(saved_model_dir):
                    print(f"\nDirectory: {root}")
                    for f in files:
                        print(f"  - {f}")
        raise



# Upload your .pt model file
from google.colab import files
uploaded = files.upload()

# Get the filename of the uploaded file
pt_model_path = next(iter(uploaded.keys()))
output_dir = "converted_model"

# Convert the model
convert_to_tfjs(pt_model_path, output_dir)


Real-time Hand Pose Detection Web Application

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Real-time Hand Pose Detection</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
    <style>
        body { 
            text-align: center; 
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 20px;
            background: #f0f0f0;
        }
        .container {
            position: relative;
            width: 640px;
            height: 480px;
            margin: 20px auto;
        }
        video, canvas { 
            position: absolute;
            left: 0;
            top: 0;
        }
        button {
            margin: 10px;
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
        }
        button:hover {
            background: #0056b3;
        }
        #status {
            padding: 10px;
            background: #fff;
            border-radius: 4px;
            display: inline-block;
        }
    </style>
</head>
<body>
    <h1>Real-time Hand Pose Detection (YOLOv8)</h1>
    <button onclick="loadModel()">Load Model</button>
    <button onclick="startWebcam()">Start Webcam</button>
    <p id="status">Model not loaded</p>

    <div class="container">
        <video id="video" width="640" height="480" autoplay></video>
        <canvas id="canvas" width="640" height="480"></canvas>
    </div>

    <script type="module">
        let model;
        let video = document.getElementById("video");
        let canvas = document.getElementById("canvas");
        let ctx = canvas.getContext("2d");

        const CONF_THRESHOLD = 0.7;
        const IOU_THRESHOLD = 0.45;
        let isProcessing = false;
        let previousDetections = [];

        // Model input size constants
        const MODEL_WIDTH = 640;
        const MODEL_HEIGHT = 640;
        const SCALE_FACTOR = 2.0; // Adjust this to make bbox larger

        async function loadModel() {
            try {
                document.getElementById("status").innerText = "Loading model...";
                model = await tf.loadGraphModel('http://localhost:8000/model.json');
                document.getElementById("status").innerText = "Model loaded!";
                console.log("Model loaded successfully");
            } catch (error) {
                console.error("Error loading model:", error);
                document.getElementById("status").innerText = "Error loading model!";
            }
        }

        async function startWebcam() {
            if (!model) {
                alert("Please load the model first!");
                return;
            }

            try {
                const stream = await navigator.mediaDevices.getUserMedia({ 
                    video: { 
                        width: { ideal: 640 },
                        height: { ideal: 480 },
                        facingMode: 'user'
                    } 
                });
                video.srcObject = stream;
                video.onloadedmetadata = () => {
                    video.play();
                    processVideoFrame();
                };
            } catch (err) {
                console.error("Error accessing webcam:", err);
                document.getElementById("status").innerText = "Error accessing webcam!";
            }
        }

        async function processVideoFrame() {
            if (!model || !video.videoWidth || isProcessing) return;
            
            try {
                isProcessing = true;
                
                // Create a square input for the model while maintaining aspect ratio
                const offscreenCanvas = document.createElement('canvas');
                offscreenCanvas.width = MODEL_WIDTH;
                offscreenCanvas.height = MODEL_HEIGHT;
                const offscreenCtx = offscreenCanvas.getContext('2d');
                
                // Calculate scaling to maintain aspect ratio
                const scale = Math.min(MODEL_WIDTH / video.videoWidth, MODEL_HEIGHT / video.videoHeight);
                const scaledWidth = video.videoWidth * scale;
                const scaledHeight = video.videoHeight * scale;
                const offsetX = (MODEL_WIDTH - scaledWidth) / 2;
                const offsetY = (MODEL_HEIGHT - scaledHeight) / 2;
                
                offscreenCtx.fillStyle = 'black';
                offscreenCtx.fillRect(0, 0, MODEL_WIDTH, MODEL_HEIGHT);
                offscreenCtx.drawImage(video, offsetX, offsetY, scaledWidth, scaledHeight);
                
                const imgTensor = tf.tidy(() => {
                    return tf.browser.fromPixels(offscreenCanvas)
                        .expandDims(0)
                        .toFloat()
                        .div(255.0);
                });
        
                const predictions = await model.predict(imgTensor);
                imgTensor.dispose();
                
                const processedDetections = await processDetections(predictions, {
                    offsetX,
                    offsetY,
                    scale,
                    originalWidth: video.videoWidth,
                    originalHeight: video.videoHeight
                });
                
                const smoothedDetections = smoothDetections(processedDetections);
                drawDetections(smoothedDetections);
                
                previousDetections = smoothedDetections;
                
                if (Array.isArray(predictions)) {
                    predictions.forEach(p => p.dispose());
                } else {
                    predictions.dispose();
                }
                
            } catch (error) {
                console.error("Error in processing frame:", error);
            } finally {
                isProcessing = false;
                requestAnimationFrame(processVideoFrame);
            }
        }

        async function processDetections(predictionTensor, transformInfo) {
            const predictions = await predictionTensor.array();
            
            if (!predictions.length || !predictions[0].length) {
                return [];
            }
            
            let detections = [];
            const numDetections = predictions[0][0].length;
            
            for (let i = 0; i < numDetections; i++) {
                const confidence = predictions[0][4][i];
                
                if (confidence > CONF_THRESHOLD) {
                    // Get raw coordinates from model output
                    let x = (predictions[0][0][i] - transformInfo.offsetX) / transformInfo.scale;
                    let y = (predictions[0][1][i] - transformInfo.offsetY) / transformInfo.scale;
                    let width = (predictions[0][2][i] / transformInfo.scale) * SCALE_FACTOR;
                    let height = (predictions[0][3][i] / transformInfo.scale) * SCALE_FACTOR;
                    
                    // Get keypoint (assuming wrist point)
                    let kp_x = (predictions[0][5][i] - transformInfo.offsetX) / transformInfo.scale;
                    let kp_y = (predictions[0][6][i] - transformInfo.offsetY) / transformInfo.scale;
                    
                    // Normalize coordinates
                    x = x / transformInfo.originalWidth;
                    y = y / transformInfo.originalHeight;
                    width = width / transformInfo.originalWidth;
                    height = height / transformInfo.originalHeight;
                    kp_x = kp_x / transformInfo.originalWidth;
                    kp_y = kp_y / transformInfo.originalHeight;
                    
                    // Ensure coordinates are within bounds
                    x = Math.max(0, Math.min(1, x));
                    y = Math.max(0, Math.min(1, y));
                    kp_x = Math.max(0, Math.min(1, kp_x));
                    kp_y = Math.max(0, Math.min(1, kp_y));
                    
                    detections.push({
                        bbox: [x, y, width, height],
                        confidence,
                        keypoint: [kp_x, kp_y]
                    });
                }
            }
            
            return applyNMS(detections);
        }

        function smoothDetections(currentDetections) {
            if (!previousDetections.length) return currentDetections;
            
            return currentDetections.map(detection => {
                const prevDetection = findClosestPreviousDetection(detection, previousDetections);
                if (prevDetection) {
                    const alpha = 0.7;
                    return {
                        bbox: detection.bbox.map((coord, i) => 
                            alpha * coord + (1 - alpha) * prevDetection.bbox[i]
                        ),
                        confidence: detection.confidence,
                        keypoint: detection.keypoint.map((coord, i) => 
                            alpha * coord + (1 - alpha) * prevDetection.keypoint[i]
                        )
                    };
                }
                return detection;
            });
        }

        function findClosestPreviousDetection(detection, previousDetections) {
            if (!previousDetections.length) return null;
            
            let minDist = Infinity;
            let closestDetection = null;
            
            previousDetections.forEach(prevDetection => {
                const dist = Math.sqrt(
                    Math.pow(detection.keypoint[0] - prevDetection.keypoint[0], 2) +
                    Math.pow(detection.keypoint[1] - prevDetection.keypoint[1], 2)
                );
                
                if (dist < minDist) {
                    minDist = dist;
                    closestDetection = prevDetection;
                }
            });
            
            return minDist < 0.3 ? closestDetection : null;
        }

        function calculateIoU(box1, box2) {
            const [x1, y1, w1, h1] = box1;
            const [x2, y2, w2, h2] = box2;
            
            const x1min = x1 - w1/2;
            const x1max = x1 + w1/2;
            const y1min = y1 - h1/2;
            const y1max = y1 + h1/2;
            
            const x2min = x2 - w2/2;
            const x2max = x2 + w2/2;
            const y2min = y2 - h2/2;
            const y2max = y2 + h2/2;
            
            const xOverlap = Math.max(0, Math.min(x1max, x2max) - Math.max(x1min, x2min));
            const yOverlap = Math.max(0, Math.min(y1max, y2max) - Math.max(y1min, y2min));
            
            const intersectionArea = xOverlap * yOverlap;
            const union = w1 * h1 + w2 * h2 - intersectionArea;
            
            return intersectionArea / union;
        }

        async function applyNMS(detections) {
            detections.sort((a, b) => b.confidence - a.confidence);
            
            const selected = [];
            const active = new Set(Array(detections.length).keys());
            
            for (let i = 0; i < detections.length; i++) {
                if (!active.has(i)) continue;
                
                selected.push(detections[i]);
                
                for (let j = i + 1; j < detections.length; j++) {
                    if (!active.has(j)) continue;
                    
                    const iou = calculateIoU(detections[i].bbox, detections[j].bbox);
                    if (iou >= IOU_THRESHOLD) active.delete(j);
                }
            }
            
            return selected;
        }

        function drawDetections(detections) {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
            
            detections.forEach(detection => {
                const [x, y, width, height] = detection.bbox;
                const [keypointX, keypointY] = detection.keypoint;
                
                // Convert normalized coordinates to pixel values
                const boxX = (x - width/2) * canvas.width;
                const boxY = (y - height/2) * canvas.height;
                const boxWidth = width * canvas.width;
                const boxHeight = height * canvas.height;
                
                // Draw bounding box
                ctx.strokeStyle = 'red';
                ctx.lineWidth = 2;
                ctx.strokeRect(boxX, boxY, boxWidth, boxHeight);
                
                // Draw keypoint
                const kpX = keypointX * canvas.width;
                const kpY = keypointY * canvas.height;
                
                ctx.fillStyle = 'blue';
                ctx.beginPath();
                ctx.arc(kpX, kpY, 5, 0, 2 * Math.PI);
                ctx.fill();
                
                // Draw confidence score
                ctx.fillStyle = 'red';
                ctx.font = '14px Arial';
                ctx.fillText(`Conf: ${detection.confidence.toFixed(2)}`, boxX, boxY - 5);

                // Draw lines from bbox center to keypoint
                ctx.beginPath();
                ctx.moveTo(boxX + boxWidth/2, boxY + boxHeight/2);
                ctx.lineTo(kpX, kpY);
                ctx.strokeStyle = 'green';
                ctx.stroke();
            });
        }

        window.loadModel = loadModel;
        window.startWebcam = startWebcam;
    </script>
</body>
</html>

Hand wrist detection

import os
import yaml
from ultralytics import YOLO

class HandWristDetector:
    def __init__(self, config_path='config.yaml'):
        """
        Initialize HandWristDetector with configuration
        
        Args:
            config_path (str): Path to the configuration YAML file
        """
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        
        # Initialize YOLO pose detection model
        model_size = self.config['model']['size']
        model_path = f"yolov8{model_size}-pose.pt"
        
        # The YOLO constructor downloads the pretrained weights automatically if missing
        if not os.path.exists(model_path):
            print(f"Downloading YOLOv8{model_size} pose model...")
        
        self.model = YOLO(model_path)
        
    def train(self, data_yaml):
        """
        Train the model with custom configuration
        
        Args:
            data_yaml (str): Path to the data YAML file containing dataset configuration
            
        Returns:
            results: Training results object
        """
        # Set training arguments
        args = dict(
            data=data_yaml,                    
            task='pose',                       
            mode='train',                      
            model=self.model,                  
            epochs=self.config['model']['epochs'],
            imgsz=self.config['model']['image_size'],
            batch=self.config['model']['batch_size'],
            device='',                         
            workers=8,                         
            optimizer='AdamW',                  
            patience=20,                       
            verbose=True,                      
            seed=0,                           
            deterministic=True,                
            single_cls=True,                   
            rect=True,                         
            cos_lr=True,                       
            close_mosaic=10,                   
            resume=False,                      
            amp=True,                          
            
            # Learning rate settings
            lr0=0.001,                        
            lrf=0.01,                         
            momentum=0.937,                    
            weight_decay=0.0005,              
            warmup_epochs=3.0,                
            warmup_momentum=0.8,              
            warmup_bias_lr=0.1,               
            
            # Loss coefficients
            box=7.5,                          
            cls=0.5,                          
            pose=12.0,                        
            kobj=2.0,                         
            
            # Augmentation settings
            degrees=10.0,                      
            translate=0.2,                    
            scale=0.7,                        
            fliplr=0.5,                       
            mosaic=1.0,                       
            mixup=0.0,                        
            
            # Saving settings
            project='runs/pose',              
            name='train',                     
            exist_ok=False,                   
            pretrained=True,                  
            plots=True,                       
            save=True,                        
            save_period=-1,                   
            
            # Validation settings
            val=True,                         
            save_json=False,                  
            conf=None,                        
            iou=0.7,                          
            max_det=300,                      
            
            # Advanced settings
            fraction=1.0,                    
            profile=False,                    
            overlap_mask=True,                
            mask_ratio=4,                     
            dropout=0.2,                      
            label_smoothing=0.1,              
            nbs=64,                          
        )
        
        # Start training
        try:
            results = self.model.train(**args)
            return results
        except Exception as e:
            print(f"Training error: {str(e)}")
            raise
    
    def evaluate(self, data_yaml):
        """
        Evaluate the model on validation/test set
        
        Args:
            data_yaml (str): Path to the data YAML file
            
        Returns:
            results: Validation results object
        """
        try:
            results = self.model.val(
                data=data_yaml,
                imgsz=self.config['model']['image_size'],
                batch=self.config['model']['batch_size'],
                conf=0.25,
                iou=0.7,
                device='',
                verbose=True,
                save_json=False,
                save_hybrid=False,
                max_det=300,
                half=False
            )
            return results
        except Exception as e:
            print(f"Evaluation error: {str(e)}")
            raise
    
    def export_model(self, format='onnx'):
        """
        Export the model to specified format
        
        Args:
            format (str): Format to export to ('onnx' or 'tflite')
        """
        try:
            if format == 'onnx':
                self.model.export(
                    format='onnx',
                    dynamic=True,
                    simplify=True,
                    opset=11,
                    device='cpu'
                )
            elif format == 'tflite':
                self.model.export(
                    format='tflite',
                    int8=True,
                    device='cpu'
                )
        except Exception as e:
            print(f"Export error: {str(e)}")
            raise
    
    def predict(self, image_path):
        """
        Run inference on a single image
        
        Args:
            image_path (str): Path to the input image
            
        Returns:
            results: Detection results object
        """
        try:
            results = self.model.predict(
                source=image_path,
                conf=0.25,
                iou=0.45,
                imgsz=self.config['model']['image_size'],
                device='',
                verbose=False,
                save=True,
                save_txt=False,
                save_conf=False,
                save_crop=False,
                show_labels=True,
                show_conf=True,
                max_det=300,
                agnostic_nms=False,
                classes=None,
                retina_masks=False,
                boxes=True
            )
            return results[0]
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            raise
    
    def predict_batch(self, image_paths):
        """
        Run inference on a batch of images
        
        Args:
            image_paths (list): List of paths to input images
            
        Returns:
            results: List of detection results objects
        """
        try:
            results = self.model.predict(
                source=image_paths,
                conf=0.25,
                iou=0.45,
                imgsz=self.config['model']['image_size'],
                batch=self.config['model']['batch_size']
            )
            return results
        except Exception as e:
            print(f"Batch prediction error: {str(e)}")
            raise

config.yaml

paths: 
  hand_img_dir: "/train/images"
  non_hand_dir: "/non-hands"        
  annotations_dir: "/train/labels"
  output_dir: "/Hand_wrist_keypoint"


model:
  size: "n"  
  epochs: 50  
  image_size: 224  
  batch_size: 16  
  pretrained: true 
  conf_thres: 0.25  
  iou_thres: 0.45  
  device: ""  

training:
  train_ratio: 0.7
  val_ratio: 0.15
  seed: 42

**What Actually Happened:**

The model does detect things, but with significant issues:

The bounding boxes and keypoints appear, but not where they should be – they’re incorrectly positioned relative to my actual hand
Multiple overlapping detections occur for a single hand, suggesting NMS isn’t working properly
The model unexpectedly detects my face, even though it was trained only for hand detection
There’s no stability in the detections – they jitter and move erratically
While the model technically “works” (it produces outputs), the detections are so misaligned and unstable that they’re unusable

Why aren’t you exporting directly to tfjs using Ultralytics? Is this conversion code from GPT?
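
For what it’s worth, Ultralytics can produce the TF.js graph model in one step. A minimal sketch, assuming best.pt is in the working directory and imgsz is set to whatever size you feed the model at inference time:

from ultralytics import YOLO

# One-step export to TensorFlow.js
model = YOLO("best.pt")
model.export(format="tfjs", imgsz=224)
# This writes a <model-name>_web_model/ folder containing model.json and the
# binary weight shards, which tf.loadGraphModel() can load directly.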


You can also export with nms=True so that the model includes NMS.
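
For example (same sketch as above, with NMS baked into the exported graph):

from ultralytics import YOLO

model = YOLO("best.pt")
# nms=True adds non-maximum suppression to the exported model,
# so the browser-side code no longer needs its own NMS pass.
model.export(format="tfjs", imgsz=224, nms=True)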
