
Issue with TensorFlow.js Conversion – YOLOv8-Pose Not Detecting Hand & Wrist Keypoint


I’m looking for help with my machine learning model that detects my hand and one wrist keypoint. After training, the model correctly detects my hand with a bounding box and the wrist keypoint in PyTorch. However, after converting the best.pt file to a TensorFlow.js model, detection fails: it no longer detects my hand or the keypoint.

Model Details

- YOLOv8 trained for pose detection
- Custom dataset with hand images and wrist keypoint annotations
- Input size: 224x224
- The model works correctly in the PyTorch environment
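
In PyTorch, the check that the trained weights work looks roughly like this (a minimal sketch; best.pt and the test image path are placeholders):

from ultralytics import YOLO

# Sketch: sanity-check the trained pose weights in PyTorch (paths are placeholders)
model = YOLO("best.pt")
results = model.predict("test_hand.jpg", imgsz=224, conf=0.25)

r = results[0]
print(r.boxes.xyxy)       # hand bounding box(es), pixel coordinates
print(r.keypoints.xy)     # wrist keypoint(s), one (x, y) per detection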

Here is how I did my conversion:

import os
from ultralytics import YOLO
import shutil
import tensorflow as tf
from google.colab import files

def find_saved_model(base_path):
    """Find the SavedModel directory in the export path"""
    for root, dirs, files in os.walk(base_path):
        if 'saved_model.pb' in files:
            return root
    return None

def add_signatures(saved_model_dir):
    """Load the SavedModel and add required signatures"""
    print("Adding signatures to SavedModel...")

    # Load the model
    model = tf.saved_model.load(saved_model_dir)

    # Create a wrapper function that matches the model's interface
    @tf.function(input_signature=[
        tf.TensorSpec(shape=[1, 640, 640, 3], dtype=tf.float32, name='images')
    ])
    def serving_fn(images):
        # Call model directly without training parameter
        return model(images)

    # Convert the model
    concrete_func = serving_fn.get_concrete_function()

    # Create a new SavedModel with the signature
    tf.saved_model.save(
        model,
        saved_model_dir,
        signatures={
            'serving_default': concrete_func
        }
    )

    print("Signatures added successfully")
    return saved_model_dir

def convert_to_tfjs(pt_model_path, output_dir):
    """
    Convert a PyTorch YOLO model to TensorFlow.js format

    Args:
        pt_model_path (str): Path to the .pt file
        output_dir (str): Directory to save the converted model
    """
    try:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Load the model
        print(f"Loading YOLO model from {pt_model_path}...")
        model = YOLO(pt_model_path)

        # First export to TensorFlow format
        print("Exporting to TensorFlow format...")

        # Export the model
        success = model.export(
            format='saved_model',
            imgsz=672,
            half=False,
            simplify=True
        )

        # Find the SavedModel directory
        saved_model_dir = find_saved_model(os.path.join(os.getcwd(), "best_saved_model"))
        if not saved_model_dir:
            raise Exception(f"Cannot find SavedModel directory in {os.path.dirname(pt_model_path)}")

        print(f"Found SavedModel at: {saved_model_dir}")

        # Add signatures to the model
        saved_model_dir = add_signatures(saved_model_dir)

        # Convert to TensorFlow.js
        print("Converting to TensorFlow.js format...")
        tfjs_target_dir = os.path.join(output_dir, 'tfjs_model')

        # Ensure clean target directory
        if os.path.exists(tfjs_target_dir):
            shutil.rmtree(tfjs_target_dir)
        os.makedirs(tfjs_target_dir)

        # Try conversion with modified parameters
        conversion_command = (
            f"tensorflowjs_converter "
            f"--input_format=tf_saved_model "
            f"--output_format=tfjs_graph_model "
            f"--saved_model_tags=serve "
            f"--control_flow_v2=True "
            f"'{saved_model_dir}' "
            f"'{tfjs_target_dir}'"
        )

        print(f"Running conversion command: {conversion_command}")
        result = os.system(conversion_command)

        if result != 0:
            raise Exception("TensorFlow.js conversion failed")

        # Verify conversion
        if not os.path.exists(os.path.join(tfjs_target_dir, 'model.json')):
            raise Exception("TensorFlow.js conversion failed - model.json not found")

        print(f"Successfully converted model to TensorFlow.js format")
        print(f"Output saved to: {tfjs_target_dir}")

        # Print model files
        print("\nConverted model files:")
        for file in os.listdir(tfjs_target_dir):
            print(f"- {file}")

        # Create a zip file of the converted model
        shutil.make_archive(tfjs_target_dir, 'zip', tfjs_target_dir)

        # Download the zip file
        files.download("converted_model/tfjs_model.zip")

    except Exception as e:
        print(f"Error during conversion: {str(e)}")
        print("\nDebug information:")
        print(f"Current working directory: {os.getcwd()}")
        print(f"PT model exists: {os.path.exists(pt_model_path)}")
        if 'saved_model_dir' in locals():
            print(f"SavedModel directory exists: {os.path.exists(saved_model_dir)}")
            if os.path.exists(saved_model_dir):
                print("SavedModel contents:")
                for root, dirs, files in os.walk(saved_model_dir):
                    print(f"\nDirectory: {root}")
                    for f in files:
                        print(f"  - {f}")
        raise



# Upload your .pt model file
from google.colab import files
uploaded = files.upload()

#Get the filename of the uploaded file
pt_model_path = next(iter(uploaded.keys()))
output_dir = "converted_model"

# Convert the model
convert_to_tfjs(pt_model_path, output_dir)
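
To sanity-check the export before running the converter, the SavedModel signature can be inspected like this (a minimal sketch; the best_saved_model path is the directory produced by model.export above):

import tensorflow as tf

# Sketch: confirm the serving signature the tensorflowjs_converter will see
loaded = tf.saved_model.load("best_saved_model")
fn = loaded.signatures["serving_default"]
print(fn.structured_input_signature)   # expected input name, shape and dtype
print(fn.structured_outputs)           # output tensor names and shapes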

Real-time Hand Pose Detection Web Application

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Real-time Hand Pose Detection</title>
    <script src="/@tensorflow/tfjs"></script>
    <style>
        body { 
            text-align: center; 
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 20px;
            background: #f0f0f0;
        }
        .container {
            position: relative;
            width: 640px;
            height: 480px;
            margin: 20px auto;
        }
        video, canvas { 
            position: absolute;
            left: 0;
            top: 0;
        }
        button {
            margin: 10px;
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
        }
        button:hover {
            background: #0056b3;
        }
        #status {
            padding: 10px;
            background: #fff;
            border-radius: 4px;
            display: inline-block;
        }
    </style>
</head>
<body>
    <h1>Real-time Hand Pose Detection (YOLOv8)</h1>
    <button onclick="loadModel()">Load Model</button>
    <button onclick="startWebcam()">Start Webcam</button>
    <p id="status">Model not loaded</p>

    <div class="container">
        <video id="video" width="640" height="480" autoplay></video>
        <canvas id="canvas" width="640" height="480"></canvas>
    </div>

    <script type="module">
        let model;
        let video = document.getElementById("video");
        let canvas = document.getElementById("canvas");
        let ctx = canvas.getContext("2d");

        const CONF_THRESHOLD = 0.7;
        const IOU_THRESHOLD = 0.45;
        let isProcessing = false;
        let previousDetections = [];

        // Model input size constants
        const MODEL_WIDTH = 640;
        const MODEL_HEIGHT = 640;
        const SCALE_FACTOR = 2.0; // Adjust this to make bbox larger

        async function loadModel() {
            try {
                document.getElementById("status").innerText = "Loading model...";
                model = await tf.loadGraphModel('http://localhost:8000/model.json');
                document.getElementById("status").innerText = "Model loaded!";
                console.log("Model loaded successfully");
            } catch (error) {
                console.error("Error loading model:", error);
                document.getElementById("status").innerText = "Error loading model!";
            }
        }

        async function startWebcam() {
            if (!model) {
                alert("Please load the model first!");
                return;
            }

            try {
                const stream = await navigator.mediaDevices.getUserMedia({ 
                    video: { 
                        width: { ideal: 640 },
                        height: { ideal: 480 },
                        facingMode: 'user'
                    } 
                });
                video.srcObject = stream;
                video.onloadedmetadata = () => {
                    video.play();
                    processVideoFrame();
                };
            } catch (err) {
                console.error("Error accessing webcam:", err);
                document.getElementById("status").innerText = "Error accessing webcam!";
            }
        }

        async function processVideoFrame() {
            if (!model || !video.videoWidth || isProcessing) return;
            
            try {
                isProcessing = true;
                
                // Create a square input for the model while maintaining aspect ratio
                const offscreenCanvas = document.createElement('canvas');
                offscreenCanvas.width = MODEL_WIDTH;
                offscreenCanvas.height = MODEL_HEIGHT;
                const offscreenCtx = offscreenCanvas.getContext('2d');
                
                // Calculate scaling to maintain aspect ratio
                const scale = Math.min(MODEL_WIDTH / video.videoWidth, MODEL_HEIGHT / video.videoHeight);
                const scaledWidth = video.videoWidth * scale;
                const scaledHeight = video.videoHeight * scale;
                const offsetX = (MODEL_WIDTH - scaledWidth) / 2;
                const offsetY = (MODEL_HEIGHT - scaledHeight) / 2;
                
                offscreenCtx.fillStyle = 'black';
                offscreenCtx.fillRect(0, 0, MODEL_WIDTH, MODEL_HEIGHT);
                offscreenCtx.drawImage(video, offsetX, offsetY, scaledWidth, scaledHeight);
                
                const imgTensor = tf.tidy(() => {
                    return tf.browser.fromPixels(offscreenCanvas)
                        .expandDims(0)
                        .toFloat()
                        .div(255.0);
                });
        
                const predictions = await model.predict(imgTensor);
                imgTensor.dispose();
                
                const processedDetections = await processDetections(predictions, {
                    offsetX,
                    offsetY,
                    scale,
                    originalWidth: video.videoWidth,
                    originalHeight: video.videoHeight
                });
                
                const smoothedDetections = smoothDetections(processedDetections);
                drawDetections(smoothedDetections);
                
                previousDetections = smoothedDetections;
                
                if (Array.isArray(predictions)) {
                    predictions.forEach(p => p.dispose());
                } else {
                    predictions.dispose();
                }
                
            } catch (error) {
                console.error("Error in processing frame:", error);
            } finally {
                isProcessing = false;
                requestAnimationFrame(processVideoFrame);
            }
        }

        async function processDetections(predictionTensor, transformInfo) {
            const predictions = await predictionTensor.array();
            
            if (!predictions.length || !predictions[0].length) {
                return [];
            }
            
            let detections = [];
            const numDetections = predictions[0][0].length;
            
            for (let i = 0; i < numDetections; i++) {
                const confidence = predictions[0][4][i];
                
                if (confidence > CONF_THRESHOLD) {
                    // Get raw coordinates from model output
                    let x = (predictions[0][0][i] - transformInfo.offsetX) / transformInfo.scale;
                    let y = (predictions[0][1][i] - transformInfo.offsetY) / transformInfo.scale;
                    let width = (predictions[0][2][i] / transformInfo.scale) * SCALE_FACTOR;
                    let height = (predictions[0][3][i] / transformInfo.scale) * SCALE_FACTOR;
                    
                    // Get keypoint (assuming wrist point)
                    let kp_x = (predictions[0][5][i] - transformInfo.offsetX) / transformInfo.scale;
                    let kp_y = (predictions[0][6][i] - transformInfo.offsetY) / transformInfo.scale;
                    
                    // Normalize coordinates
                    x = x / transformInfo.originalWidth;
                    y = y / transformInfo.originalHeight;
                    width = width / transformInfo.originalWidth;
                    height = height / transformInfo.originalHeight;
                    kp_x = kp_x / transformInfo.originalWidth;
                    kp_y = kp_y / transformInfo.originalHeight;
                    
                    // Ensure coordinates are within bounds
                    x = Math.max(0, Math.min(1, x));
                    y = Math.max(0, Math.min(1, y));
                    kp_x = Math.max(0, Math.min(1, kp_x));
                    kp_y = Math.max(0, Math.min(1, kp_y));
                    
                    detections.push({
                        bbox: [x, y, width, height],
                        confidence,
                        keypoint: [kp_x, kp_y]
                    });
                }
            }
            
            return applyNMS(detections);
        }

        function smoothDetections(currentDetections) {
            if (!previousDetections.length) return currentDetections;
            
            return currentDetections.map(detection => {
                const prevDetection = findClosestPreviousDetection(detection, previousDetections);
                if (prevDetection) {
                    const alpha = 0.7;
                    return {
                        bbox: detection.bbox.map((coord, i) => 
                            alpha * coord + (1 - alpha) * prevDetection.bbox[i]
                        ),
                        confidence: detection.confidence,
                        keypoint: detection.keypoint.map((coord, i) => 
                            alpha * coord + (1 - alpha) * prevDetection.keypoint[i]
                        )
                    };
                }
                return detection;
            });
        }

        function findClosestPreviousDetection(detection, previousDetections) {
            if (!previousDetections.length) return null;
            
            let minDist = Infinity;
            let closestDetection = null;
            
            previousDetections.forEach(prevDetection => {
                const dist = Math.sqrt(
                    Math.pow(detection.keypoint[0] - prevDetection.keypoint[0], 2) +
                    Math.pow(detection.keypoint[1] - prevDetection.keypoint[1], 2)
                );
                
                if (dist < minDist) {
                    minDist = dist;
                    closestDetection = prevDetection;
                }
            });
            
            return minDist < 0.3 ? closestDetection : null;
        }

        function calculateIoU(box1, box2) {
            const [x1, y1, w1, h1] = box1;
            const [x2, y2, w2, h2] = box2;
            
            const x1min = x1 - w1/2;
            const x1max = x1 + w1/2;
            const y1min = y1 - h1/2;
            const y1max = y1 + h1/2;
            
            const x2min = x2 - w2/2;
            const x2max = x2 + w2/2;
            const y2min = y2 - h2/2;
            const y2max = y2 + h2/2;
            
            const xOverlap = Math.max(0, Math.min(x1max, x2max) - Math.max(x1min, x2min));
            const yOverlap = Math.max(0, Math.min(y1max, y2max) - Math.max(y1min, y2min));
            
            const intersectionArea = xOverlap * yOverlap;
            const union = w1 * h1 + w2 * h2 - intersectionArea;
            
            return intersectionArea / union;
        }

        async function applyNMS(detections) {
            detections.sort((a, b) => b.confidence - a.confidence);
            
            const selected = [];
            const active = new Set(Array(detections.length).keys());
            
            for (let i = 0; i < detections.length; i++) {
                if (!active.has(i)) continue;
                
                selected.push(detections[i]);
                
                for (let j = i + 1; j < detections.length; j++) {
                    if (!active.has(j)) continue;
                    
                    const iou = calculateIoU(detections[i].bbox, detections[j].bbox);
                    if (iou >= IOU_THRESHOLD) active.delete(j);
                }
            }
            
            return selected;
        }

        function drawDetections(detections) {
            ctx.clearRect(0, 0, canvas.width, canvas.height);
            ctx.drawImage(video, 0, 0, canvas.width, canvas.height);
            
            detections.forEach(detection => {
                const [x, y, width, height] = detection.bbox;
                const [keypointX, keypointY] = detection.keypoint;
                
                // Convert normalized coordinates to pixel values
                const boxX = (x - width/2) * canvas.width;
                const boxY = (y - height/2) * canvas.height;
                const boxWidth = width * canvas.width;
                const boxHeight = height * canvas.height;
                
                // Draw bounding box
                ctx.strokeStyle = 'red';
                ctx.lineWidth = 2;
                ctx.strokeRect(boxX, boxY, boxWidth, boxHeight);
                
                // Draw keypoint
                const kpX = keypointX * canvas.width;
                const kpY = keypointY * canvas.height;
                
                ctx.fillStyle = 'blue';
                ctx.beginPath();
                ctx.arc(kpX, kpY, 5, 0, 2 * Math.PI);
                ctx.fill();
                
                // Draw confidence score
                ctx.fillStyle = 'red';
                ctx.font = '14px Arial';
                ctx.fillText(`Conf: ${detection.confidence.toFixed(2)}`, boxX, boxY - 5);

                // Draw lines from bbox center to keypoint
                ctx.beginPath();
                ctx.moveTo(boxX + boxWidth/2, boxY + boxHeight/2);
                ctx.lineTo(kpX, kpY);
                ctx.strokeStyle = 'green';
                ctx.stroke();
            });
        }

        window.loadModel = loadModel;
        window.startWebcam = startWebcam;
    </script>
</body>
</html>
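
The processDetections code above assumes the prediction has layout [1, channels, N] with channel order (cx, cy, w, h, conf, kp_x, kp_y). A minimal Python sketch to confirm that layout against the exported SavedModel (the path and the 640x640 input size are assumptions carried over from the conversion step):

import tensorflow as tf

# Sketch: run a dummy frame through the SavedModel and print the raw output shape
loaded = tf.saved_model.load("best_saved_model")
infer = loaded.signatures["serving_default"]

dummy = tf.zeros([1, 640, 640, 3], dtype=tf.float32)  # matches the signature added during conversion
out = infer(images=dummy)
for name, tensor in out.items():
    # Expected something like (1, 7, N) for one class + one (x, y) keypoint,
    # or (1, 8, N) if keypoint visibility is exported (kpt_shape [1, 3])
    print(name, tensor.shape)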

Hand wrist detection

import os
import onnx
import time
import yaml
import torch
import numpy as np
from pathlib import Path
from ultralytics import YOLO

class HandWristDetector:
    def __init__(self, config_path='config.yaml'):
        """
        Initialize HandWristDetector with configuration
        
        Args:
            config_path (str): Path to the configuration YAML file
        """
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        
        # Initialize YOLO pose detection model
        model_size = self.config['model']['size']
        model_path = f"yolov8{model_size}-pose.pt"
        
        # Download model if not exists
        if not os.path.exists(model_path):
            print(f"Downloading YOLOv8{model_size} pose model...")
        
        self.model = YOLO(model_path)
        
    def train(self, data_yaml):
        """
        Train the model with custom configuration
        
        Args:
            data_yaml (str): Path to the data YAML file containing dataset configuration
            
        Returns:
            results: Training results object
        """
        # Set training arguments
        args = dict(
            data=data_yaml,                    
            task='pose',                       
            mode='train',                      
            model=self.model,                  
            epochs=self.config['model']['epochs'],
            imgsz=self.config['model']['image_size'],
            batch=self.config['model']['batch_size'],
            device='',                         
            workers=8,                         
            optimizer='AdamW',                  
            patience=20,                       
            verbose=True,                      
            seed=0,                           
            deterministic=True,                
            single_cls=True,                   
            rect=True,                         
            cos_lr=True,                       
            close_mosaic=10,                   
            resume=False,                      
            amp=True,                          
            
            # Learning rate settings
            lr0=0.001,                        
            lrf=0.01,                         
            momentum=0.937,                    
            weight_decay=0.0005,              
            warmup_epochs=3.0,                
            warmup_momentum=0.8,              
            warmup_bias_lr=0.1,               
            
            # Loss coefficients
            box=7.5,                          
            cls=0.5,                          
            pose=12.0,                        
            kobj=2.0,                         
            
            # Augmentation settings
            degrees=10.0,                      
            translate=0.2,                    
            scale=0.7,                        
            fliplr=0.5,                       
            mosaic=1.0,                       
            mixup=0.0,                        
            
            # Saving settings
            project='runs/pose',              
            name='train',                     
            exist_ok=False,                   
            pretrained=True,                  
            plots=True,                       
            save=True,                        
            save_period=-1,                   
            
            # Validation settings
            val=True,                         
            save_json=False,                  
            conf=None,                        
            iou=0.7,                          
            max_det=300,                      
            
            # Advanced settings
            fraction=1.0,                    
            profile=False,                    
            overlap_mask=True,                
            mask_ratio=4,                     
            dropout=0.2,                      
            label_smoothing=0.1,              
            nbs=64,                          
        )
        
        # Start training
        try:
            results = self.model.train(**args)
            return results
        except Exception as e:
            print(f"Training error: {str(e)}")
            raise
    
    def evaluate(self, data_yaml):
        """
        Evaluate the model on validation/test set
        
        Args:
            data_yaml (str): Path to the data YAML file
            
        Returns:
            results: Validation results object
        """
        try:
            results = self.model.val(
                data=data_yaml,
                imgsz=self.config['model']['image_size'],
                batch=self.config['model']['batch_size'],
                conf=0.25,
                iou=0.7,
                device='',
                verbose=True,
                save_json=False,
                save_hybrid=False,
                max_det=300,
                half=False
            )
            return results
        except Exception as e:
            print(f"Evaluation error: {str(e)}")
            raise
    
    def export_model(self, format='onnx'):
        """
        Export the model to specified format
        
        Args:
            format (str): Format to export to ('onnx' or 'tflite')
        """
        try:
            if format == 'onnx':
                self.model.export(
                    format='onnx',
                    dynamic=True,
                    simplify=True,
                    opset=11,
                    device='cpu'
                )
            elif format == 'tflite':
                self.model.export(
                    format='tflite',
                    int8=True,
                    device='cpu'
                )
        except Exception as e:
            print(f"Export error: {str(e)}")
            raise
    
    def predict(self, image_path):
        """
        Run inference on a single image
        
        Args:
            image_path (str): Path to the input image
            
        Returns:
            results: Detection results object
        """
        try:
            results = self.model.predict(
                source=image_path,
                conf=0.25,
                iou=0.45,
                imgsz=self.config['model']['image_size'],
                device='',
                verbose=False,
                save=True,
                save_txt=False,
                save_conf=False,
                save_crop=False,
                show_labels=True,
                show_conf=True,
                max_det=300,
                agnostic_nms=False,
                classes=None,
                retina_masks=False,
                boxes=True
            )
            return results[0]
        except Exception as e:
            print(f"Prediction error: {str(e)}")
            raise
    
    def predict_batch(self, image_paths):
        """
        Run inference on a batch of images
        
        Args:
            image_paths (list): List of paths to input images
            
        Returns:
            results: List of detection results objects
        """
        try:
            results = self.model.predict(
                source=image_paths,
                conf=0.25,
                iou=0.45,
                imgsz=self.config['model']['image_size'],
                batch=self.config['model']['batch_size']
            )
            return results
        except Exception as e:
            print(f"Batch prediction error: {str(e)}")
            raise
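
Typical usage of this class looks something like the following (a sketch; data.yaml and the image path are placeholders):

# Sketch: end-to-end usage of HandWristDetector (paths are placeholders)
detector = HandWristDetector("config.yaml")
detector.train("data.yaml")             # dataset YAML describing images, labels and kpt_shape
detector.evaluate("data.yaml")
result = detector.predict("test_hand.jpg")
print(result.keypoints.xy)              # predicted wrist location(s)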

config.yaml

paths: 
  hand_img_dir: "/train/images"
  non_hand_dir: "/non-hands"        
  annotations_dir: "/train/labels"
  output_dir: "/Hand_wrist_keypoint"


model:
  size: "n"  
  epochs: 50  
  image_size: 224  
  batch_size: 16  
  pretrained: true 
  conf_thres: 0.25  
  iou_thres: 0.45  
  device: ""  

training:
  train_ratio: 0.7
  val_ratio: 0.15
  seed: 42
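
For context, the dataset YAML passed to train() (the data_yaml argument, separate from config.yaml above) also needs the pose-specific kpt_shape entry. A sketch for a single-class, single-keypoint setup, with placeholder paths:

# data.yaml (sketch) - dataset definition passed to model.train(data=...)
path: /Hand_wrist_keypoint   # dataset root (placeholder)
train: train/images
val: valid/images

kpt_shape: [1, 2]            # one keypoint (wrist), x/y only; use [1, 3] if visibility is labeled
names:
  0: hand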

What Actually Happened

The model does detect things, but with significant issues:

- The bounding boxes and keypoints appear, but not where they should be: they're incorrectly positioned relative to my actual hand
- Multiple overlapping detections occur for a single hand, suggesting NMS isn't working properly
- The model unexpectedly detects my face, even though it was trained only for hand detection
- There's no stability in the detections; they jitter and move erratically
- While the model technically "works" (it produces outputs), the detections are so misaligned and unstable that they're unusable
