diff --git a/apps/computer-vision/app.json b/apps/computer-vision/app.json index 4d68c039b..3c1a7067c 100644 --- a/apps/computer-vision/app.json +++ b/apps/computer-vision/app.json @@ -25,11 +25,34 @@ "foregroundImage": "./assets/icons/adaptive-icon.png", "backgroundColor": "#ffffff" }, - "package": "com.anonymous.computervision" + "package": "com.anonymous.computervision", + "permissions": [ + "CAMERA", + "READ_EXTERNAL_STORAGE", + "WRITE_EXTERNAL_STORAGE" + ] }, "web": { "favicon": "./assets/icons/favicon.png" }, - "plugins": ["expo-font", "expo-router"] + "plugins": [ + "expo-font", + "expo-router", + [ + "react-native-vision-camera", + { + "cameraPermissionText": "$(PRODUCT_NAME) needs access to your Camera to classify objects in real-time.", + "enableMicrophonePermission": false + } + ], + [ + "expo-build-properties", + { + "android": { + "minSdkVersion": 26 + } + } + ] + ] } } diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 5914d2fe8..42f6584d8 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -42,7 +42,7 @@ export default function _layout() { }} > ( + drawerContent={(props: DrawerContentComponentProps) => ( )} screenOptions={{ @@ -108,6 +108,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + { + const hash = label + .split('') + .reduce((acc, char) => acc + char.charCodeAt(0), 0); + return COLORS[hash % COLORS.length]; +}; + +export default function CameraObjectDetectionScreen() { + // Model loading state + const detectionModel = useMemo(() => new ObjectDetectionModule(), []); + const [isModelReady, setIsModelReady] = useState(false); + + // Resize plugin instance + const { resize } = useResizePlugin(); + + // Screen dimensions + const { width: screenWidth, height: screenHeight } = useWindowDimensions(); + + // Detection results + const [detections, setDetections] = useState([]); + const [fps, setFps] = useState(0); + const lastFrameTimeRef = useRef(Date.now()); + const [cameraLayout, setCameraLayout] = useState({ width: 0, height: 0 }); + + const processDetectionCallback = useRunOnJS((results: Detection[]) => { + setDetections(results); + const now = Date.now(); + const timeDiff = now - lastFrameTimeRef.current; + if (timeDiff > 0) { + setFps(Math.round(1000 / timeDiff)); + } + lastFrameTimeRef.current = now; + }, []); + + // Camera permissions + const device = useCameraDevice('back'); + const format = useCameraFormat(device, [ + { videoResolution: { width: 1280, height: 720 } }, + { videoStabilizationMode: 'off' }, + ]); + const { hasPermission, requestPermission } = useCameraPermission(); + const internalModel = isModelReady ? 
detectionModel.nativeModule : null; + + // Load the model + useEffect(() => { + (async () => { + try { + await detectionModel.load(SSDLITE_320_MOBILENET_V3_LARGE); + setIsModelReady(true); + } catch (error) { + console.error('Failed to load model:', error); + } + })(); + + return () => { + detectionModel.delete(); + }; + }, [detectionModel]); + + // Frame processor throttled to 10fps via frameProcessorFps prop + const frameProcessor = useFrameProcessor( + (frame) => { + 'worklet'; + + try { + if (internalModel == null) { + return; + } + + // Resize frame to model input size (640x640 for YOLO) + const resizedArray = resize(frame, { + scale: { + width: 640, + height: 640, + }, + pixelFormat: 'rgb', + dataType: 'uint8', + rotation: '90deg', + }); + + // Create object with buffer and dimensions + const resizedData = { + data: resizedArray.buffer, + width: 640, + height: 640, + }; + + // Pass raw pixel data to model with detection threshold + const result = internalModel.generateFromFrame(resizedData, 0.5); + + // Pass results and timestamp to JS + processDetectionCallback(result); + } catch (error: any) { + console.log( + 'Frame processing error:', + error?.message || 'Unknown error' + ); + } + }, + [internalModel, resize, processDetectionCallback] + ); + + // Loading state + if (!isModelReady) { + return ( + + + + Loading model... + + + ); + } + + // Request permissions + if (!hasPermission) { + return ( + + + + Camera permission is required + + + Grant Permission + + + + ); + } + + // No camera device + if (!device) { + return ( + + + No camera device found + + + ); + } + + return ( + + + {/* Camera View */} + { + const { width, height } = event.nativeEvent.layout; + setCameraLayout({ width, height }); + }} + /> + + {/* Bounding Box Overlay with Skia */} + + + {/* Stats Overlay */} + + + FPS + {fps} + + + Objects + {detections.length} + + + + {/* Detection Info */} + {detections.length > 0 && ( + + {detections.slice(0, 5).map((detection, index) => ( + + + {detection.label.toLowerCase().replace(/_/g, ' ')} + + + {(detection.score * 100).toFixed(0)}% + + + ))} + + )} + + + ); +} + +// Bounding Box Overlay Component using Skia for smooth rendering +function BoundingBoxOverlay({ + detections, + screenWidth, + screenHeight, + cameraLayout, +}: { + detections: Detection[]; + screenWidth: number; + screenHeight: number; + cameraLayout: { width: number; height: number }; +}) { + const font = matchFont({ + fontSize: 14, + fontWeight: 'bold', + }); + + // Frame size is 640x640 because that's what we're passing to C++ as originalSize + // The C++ postprocess scales bboxes relative to this size + const frameSize = { width: 640, height: 640 }; + + // Use actual camera layout if available, otherwise fallback to screen dimensions + const cameraWidth = cameraLayout.width || screenWidth; + const cameraHeight = cameraLayout.height || screenHeight; + + // Calculate how the camera frame fits on screen (matching your working example) + const frameAspectRatio = frameSize.width / frameSize.height; + const cameraAspectRatio = cameraWidth / cameraHeight; + + let previewWidth, previewHeight, offsetX, offsetY; + + if (cameraAspectRatio > frameAspectRatio) { + // Screen is wider - pillarboxed + previewHeight = cameraHeight; + previewWidth = cameraHeight * frameAspectRatio; + offsetX = (cameraWidth - previewWidth) / 2; + offsetY = 0; + } else { + // Screen is taller - letterboxed + previewWidth = cameraWidth; + previewHeight = cameraWidth / frameAspectRatio; + offsetX = 0; + offsetY = (cameraHeight - previewHeight) / 2; 
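+        // e.g. a 640x640 model frame shown in a 1080x1920 portrait preview takes this
+        // branch (cameraAspectRatio 0.5625 < frameAspectRatio 1): previewHeight = 1080,
+        // offsetY = (1920 - 1080) / 2 = 420, which centers the overlay vertically.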
+ } + + const scaleX = previewWidth / frameSize.width; + const scaleY = previewHeight / frameSize.height; + + return ( + + {detections.map((detection, index) => { + const { bbox, label, score } = detection; + const color = getColorForLabel(label); + + // Direct transformation - rotation is handled by resize plugin + const x = bbox.x1 * scaleX + offsetX; + const y = bbox.y1 * scaleY + offsetY; + const width = (bbox.x2 - bbox.x1) * scaleX; + const height = (bbox.y2 - bbox.y1) * scaleY; + + // Label dimensions + const labelText = `${label.toLowerCase().replace(/_/g, ' ')} ${(score * 100).toFixed(0)}%`; + const labelWidth = Math.max(width, 120); + const labelHeight = 24; + + return ( + + {/* Bounding box border */} + + + {/* Label background */} + + + {/* Label text */} + + + ); + })} + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + }, + statsContainer: { + position: 'absolute', + top: 20, + right: 20, + flexDirection: 'row', + gap: 12, + }, + statBox: { + backgroundColor: 'rgba(0, 0, 0, 0.7)', + paddingHorizontal: 16, + paddingVertical: 8, + borderRadius: 12, + alignItems: 'center', + minWidth: 70, + }, + statLabel: { + color: '#888', + fontSize: 11, + fontWeight: '600', + textTransform: 'uppercase', + }, + statValue: { + color: '#fff', + fontSize: 24, + fontWeight: 'bold', + marginTop: 2, + }, + detectionList: { + position: 'absolute', + bottom: 20, + left: 20, + right: 20, + gap: 8, + }, + detectionItem: { + backgroundColor: 'rgba(0, 0, 0, 0.8)', + paddingHorizontal: 16, + paddingVertical: 12, + borderRadius: 12, + flexDirection: 'row', + justifyContent: 'space-between', + alignItems: 'center', + borderLeftWidth: 4, + }, + detectionLabel: { + color: 'white', + fontSize: 16, + fontWeight: '600', + textTransform: 'capitalize', + }, + detectionScore: { + color: '#4ECDC4', + fontSize: 16, + fontWeight: 'bold', + }, + loadingContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + }, + loadingText: { + marginTop: 16, + fontSize: 16, + color: ColorPalette.strongPrimary, + }, + errorContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + padding: 20, + }, + errorText: { + fontSize: 16, + color: '#d32f2f', + textAlign: 'center', + }, + permissionContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + padding: 20, + }, + permissionText: { + fontSize: 18, + color: ColorPalette.strongPrimary, + marginBottom: 20, + textAlign: 'center', + }, + permissionButton: { + backgroundColor: ColorPalette.strongPrimary, + paddingHorizontal: 24, + paddingVertical: 12, + borderRadius: 8, + }, + permissionButtonText: { + color: 'white', + fontSize: 16, + fontWeight: 'bold', + }, +}); diff --git a/apps/computer-vision/app/index.tsx b/apps/computer-vision/app/index.tsx index 38a77fc27..1372e2062 100644 --- a/apps/computer-vision/app/index.tsx +++ b/apps/computer-vision/app/index.tsx @@ -53,6 +53,12 @@ export default function Home() { > Image Generation + router.navigate('camera_object_detection/')} + > + Camera Object Detection (Live) + ); @@ -92,6 +98,9 @@ const styles = StyleSheet.create({ alignItems: 'center', marginBottom: 10, }, + cameraButton: { + backgroundColor: '#2563eb', + }, buttonText: { color: 'white', fontSize: fontSizes.md, diff --git a/apps/computer-vision/package.json b/apps/computer-vision/package.json index 63885109a..6603a168e 100644 --- a/apps/computer-vision/package.json +++ b/apps/computer-vision/package.json @@ -16,6 +16,7 @@ "@react-navigation/native": "^7.1.6", "@shopify/react-native-skia": "2.2.12", 
"expo": "^54.0.27", + "expo-build-properties": "~1.0.10", "expo-constants": "~18.0.11", "expo-font": "~14.0.10", "expo-linking": "~8.0.10", @@ -34,7 +35,10 @@ "react-native-screens": "~4.16.0", "react-native-svg": "15.12.1", "react-native-svg-transformer": "^1.5.0", - "react-native-worklets": "0.5.1" + "react-native-vision-camera": "^4.7.3", + "react-native-worklets": "0.5.1", + "react-native-worklets-core": "^1.3.3", + "vision-camera-resize-plugin": "^3.2.0" }, "devDependencies": { "@babel/core": "^7.25.2", diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 7a4426e06..b66576d29 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -10,9 +10,9 @@ #include #include #include -#include #include #include +#include #include #include #include diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h index d5c98763d..ff6c7fdf1 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h @@ -54,8 +54,15 @@ class RnExecutorchInstaller { meta::createConstructorArgsWithCallInvoker( args, runtime, jsCallInvoker); - auto modelImplementationPtr = std::make_shared( - std::make_from_tuple(constructorArgs)); + // Use std::apply to directly pass tuple arguments to make_shared + // This avoids creating a temporary and trying to copy it (which + // fails for non-copyable types like VisionModel) + auto modelImplementationPtr = std::apply( + [](auto &&...args) { + return std::make_shared( + std::forward(args)...); + }, + std::move(constructorArgs)); auto modelHostObject = std::make_shared>( modelImplementationPtr, jsCallInvoker); diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index c8232fe8c..5a636424d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -17,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -154,10 +156,17 @@ template class ModelHostObject : public JsiHostObject { ModelHostObject, promiseHostFunction<&Model::setFixedModel>, "setFixedModel")); } + + // Register generateFromFrame for all VisionModel subclasses + if constexpr (meta::DerivedFromOrSameAs) { + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, visionHostFunction<&Model::generateFromFrame>, + "generateFromFrame")); + } } - // A generic host function that runs synchronously, works analogously to the - // generic promise host function. + // A generic host function that runs synchronously, works analogously to + // the generic promise host function. template JSI_HOST_FUNCTION(synchronousHostFunction) { constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr); if (functionArgCount != count) { @@ -203,9 +212,70 @@ template class ModelHostObject : public JsiHostObject { } } + template JSI_HOST_FUNCTION(visionHostFunction) { + // 1. 
Check Argument Count + // (We rely on our new FunctionTraits) + constexpr std::size_t cppArgCount = + meta::FunctionTraits::arity; + + // We expect JS args = (Total C++ Args) - (2 injected args: Runtime + Value) + constexpr std::size_t expectedJsArgs = cppArgCount - 1; + log(LOG_LEVEL::Debug, cppArgCount, count); + if (count != expectedJsArgs) { + throw jsi::JSError(runtime, "Argument count mismatch in vision function"); + } + + try { + // 2. The Magic Trick 🪄 + // We get a pointer to a dummy function: void dummy(Rest...) {} + // This function has exactly the signature of the arguments we want to + // parse. + auto dummyFuncPtr = &meta::TailSignature::dummy; + + // 3. Let existing helpers do the work + // We pass the dummy pointer. The helper inspects its arguments (Rest...) + // and converts args[0]...args[N] accordingly. + // Note: We pass (args + 1) because JS args[0] is the PixelData, which we + // handle manually. Note: We use expectedJsArgs - 1 because we skipped one + // JS arg. + auto tailArgsTuple = + meta::createArgsTupleFromJsi(dummyFuncPtr, args + 1, runtime); + + // 4. Invoke + using ReturnType = + typename meta::FunctionTraits::return_type; + + if constexpr (std::is_void_v) { + std::apply( + [&](auto &&...tailArgs) { + (model.get()->*FnPtr)( + runtime, + args[0], // 1. PixelData (Manually passed) + std::forward( + tailArgs)...); // 2. The rest (Auto parsed) + }, + std::move(tailArgsTuple)); + return jsi::Value::undefined(); + } else { + auto result = std::apply( + [&](auto &&...tailArgs) { + return (model.get()->*FnPtr)( + runtime, args[0], + std::forward(tailArgs)...); + }, + std::move(tailArgsTuple)); + + return jsi_conversion::getJsiValue(std::move(result), runtime); + } + } catch (const std::exception &e) { + throw jsi::JSError(runtime, e.what()); + } + } + // A generic host function that resolves a promise with a result of a - // function. JSI arguments are converted to the types provided in the function - // signature, and the return value is converted back to JSI before resolving. + // function. JSI arguments are converted to the types provided in the + // function signature, and the return value is converted back to JSI + // before resolving. template JSI_HOST_FUNCTION(promiseHostFunction) { auto promise = Promise::createPromise( runtime, callInvoker, @@ -226,8 +296,8 @@ template class ModelHostObject : public JsiHostObject { meta::createArgsTupleFromJsi(FnPtr, args, runtime); // We need to dispatch a thread if we want the function to be - // asynchronous. In this thread all accesses to jsi::Runtime need to - // be done via the callInvoker. + // asynchronous. In this thread all accesses to jsi::Runtime + // need to be done via the callInvoker. threads::GlobalThreadPool::detach([this, promise, argsConverted = std::move(argsConverted)]() { @@ -235,16 +305,16 @@ template class ModelHostObject : public JsiHostObject { if constexpr (std::is_void_v) { - // For void functions, just call the function and resolve - // with undefined + // For void functions, just call the function and + // resolve with undefined std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); callInvoker->invokeAsync([promise](jsi::Runtime &runtime) { promise->resolve(jsi::Value::undefined()); }); } else { - // For non-void functions, capture the result and convert - // it + // For non-void functions, capture the result and + // convert it auto result = std::apply(std::bind_front(FnPtr, model), std::move(argsConverted)); // The result is copied. 
It should either be quickly
@@ -272,8 +342,8 @@ template class ModelHostObject : public JsiHostObject {
                 // This catch should be merged with the next two
                 // (std::runtime_error and jsi::JSError inherits from
                 // std::exception) HOWEVER react native has broken RTTI
-                // which breaks proper exception type checking. Remove when
-                // the following change is present in our version:
+                // which breaks proper exception type checking. Remove
+                // when the following change is present in our version:
                 // https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
                 callInvoker->invokeAsync([e = std::move(e), promise]() {
                   promise->reject(std::string(e.what()));
@@ -334,5 +404,4 @@ template class ModelHostObject : public JsiHostObject {
   std::shared_ptr model;
   std::shared_ptr callInvoker;
 };
-
 } // namespace rnexecutorch
diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
index 8290a810b..e3afd6f71 100644
--- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
+++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/FunctionHelpers.h
@@ -47,4 +47,12 @@ std::tuple createArgsTupleFromJsi(R (Model::*f)(Types...) const,
   return fillTupleFromArgs(std::index_sequence_for{}, args, runtime);
 }
+
+template <typename R, typename... Types>
+std::tuple<Types...> createArgsTupleFromJsi(R (*f)(Types...),
+                                            const jsi::Value *args,
+                                            jsi::Runtime &runtime) {
+  return fillTupleFromArgs<Types...>(std::index_sequence_for<Types...>{}, args,
+                                     runtime);
+}
 } // namespace rnexecutorch::meta
\ No newline at end of file
diff --git a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
index 85a3db449..eeba00b4e 100644
--- a/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
+++ b/packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h
@@ -4,7 +4,6 @@
 #include
 
 namespace rnexecutorch::meta {
-
 template
 concept DerivedFromOrSameAs = std::is_base_of_v;
@@ -34,4 +33,40 @@ concept ProvidesMemoryLowerBound = requires(T t) {
   { &T::getMemoryLowerBound };
 };
+
+// ---------------------------------------------------------
+// FunctionTraits
+// ---------------------------------------------------------
+template <typename T> struct FunctionTraits;
+
+// 1. Specialization for member function pointers
+template <typename R, typename ClassType, typename... Args>
+struct FunctionTraits<R (ClassType::*)(Args...)> {
+  static constexpr std::size_t arity = sizeof...(Args);
+  using return_type = R;
+  template <std::size_t N> struct arg {
+    using type = typename std::tuple_element<N, std::tuple<Args...>>::type;
+  };
+};
+
+// 2. Specialization for free/static function pointers
+// (required for TailSignature::dummy)
+template <typename R, typename... Args> struct FunctionTraits<R (*)(Args...)> {
+  static constexpr std::size_t arity = sizeof...(Args);
+  using return_type = R;
+  template <std::size_t N> struct arg {
+    using type = typename std::tuple_element<N, std::tuple<Args...>>::type;
+  };
+};
+
+// ---------------------------------------------------------
+// TailSignature helper
+// ---------------------------------------------------------
+template <typename T> struct TailSignature;
+
+template <typename R, typename ClassType, typename First, typename... Rest>
+struct TailSignature<R (ClassType::*)(First, Rest...)> {
+  // A dummy function that takes only the "Rest" arguments
+  static void dummy(Rest...)
{} +}; } // namespace rnexecutorch::meta diff --git a/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h new file mode 100644 index 000000000..9053e034f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch { +namespace models { + +/** + * @brief Base class for computer vision models that support real-time camera + * input + * + * VisionModel extends BaseModel with thread-safe inference and automatic frame + * extraction from VisionCamera. This class is designed for models that need to + * process camera frames in real-time (e.g., at 30fps). + * + * Thread Safety: + * - All inference operations are protected by a mutex + * - generateFromFrame() uses try_lock() to skip frames when the model is busy + * - This prevents blocking the camera thread and maintains smooth frame rates + * + * Usage: + * Subclasses should: + * 1. Inherit from VisionModel instead of BaseModel + * 2. Implement preprocessFrame() with model-specific preprocessing + * 3. Use inference_mutex_ when calling forward() in custom generate methods + * 4. Use lock_guard for blocking operations (JS API) + * 5. Use try_lock() for non-blocking operations (camera API) + * + * Example: + * @code + * class Classification : public VisionModel { + * public: + * std::unordered_map + * generateFromFrame(jsi::Runtime& runtime, const jsi::Value& frameValue) { + * // try_lock is handled automatically + * auto frameObject = frameValue.asObject(runtime); + * cv::Mat frame = FrameExtractor::extractFrame(runtime, frameObject); + * + * // Lock before inference + * if (!inference_mutex_.try_lock()) { + * return {}; // Skip frame if busy + * } + * std::lock_guard lock(inference_mutex_, std::adopt_lock); + * + * auto preprocessed = preprocessFrame(frame); + * // ... run inference + * } + * }; + * @endcode + */ +class VisionModel : public BaseModel { +public: + /** + * @brief Inherit constructors from BaseModel + * + * VisionModel uses the same construction pattern as BaseModel, just adding + * thread-safety on top. 
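+   * Derived classes below (Classification, ObjectDetection) keep their own
+   * constructors and simply delegate to VisionModel(modelSource, callInvoker).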
+ */ + using BaseModel::BaseModel; + + /** + * @brief Virtual destructor for proper cleanup in derived classes + */ + virtual ~VisionModel() = default; + +protected: + /** + * @brief Mutex to ensure thread-safe inference + * + * This mutex protects against race conditions when: + * - generateFromFrame() is called from VisionCamera worklet thread (30fps) + * - generate() is called from JavaScript thread simultaneously + * + * Usage guidelines: + * - Use std::lock_guard for blocking operations (JS API can wait) + * - Use try_lock() for non-blocking operations (camera should skip frames) + * + * @note Marked mutable to allow locking in const methods if needed + */ + mutable std::mutex inference_mutex_; + + /** + * @brief Preprocess a camera frame for model input + * + * This method should implement model-specific preprocessing such as: + * - Resizing to the model's expected input size + * - Color space conversion (e.g., BGR to RGB) + * - Normalization + * - Any other model-specific transformations + * + * @param frame Input frame from camera (already extracted and rotated by + * FrameExtractor) + * @return Preprocessed cv::Mat ready for tensor conversion + * + * @note The input frame is already in RGB format and rotated 90° clockwise + * @note This method is called under mutex protection in generateFromFrame() + */ + virtual cv::Mat preprocessFrame(const cv::Mat &frame) const = 0; +}; + +} // namespace models + +// Register VisionModel constructor traits +// Even though VisionModel is abstract, the metaprogramming system needs to know +// its constructor signature for derived classes +REGISTER_CONSTRUCTOR(models::VisionModel, std::string, + std::shared_ptr); + +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp index 0fba07108..216008033 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp @@ -1,6 +1,7 @@ #include "Classification.h" #include +#include #include #include @@ -12,7 +13,7 @@ namespace rnexecutorch::models::classification { Classification::Classification(const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputShapes = getAllInputShapes(); if (inputShapes.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ -34,6 +35,9 @@ Classification::Classification(const std::string &modelSource, std::unordered_map Classification::generate(std::string imageSource) { + // Lock and wait - JS API can afford to block + std::lock_guard lock(inference_mutex_); + auto inputTensor = image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]) .first; @@ -46,6 +50,54 @@ Classification::generate(std::string imageSource) { return postprocess(forwardResult->at(0).toTensor()); } +std::unordered_map Classification::generateFromFrame( + jsi::Runtime &runtime, const jsi::Value &pixelData, int width, int height) { + // Try-lock: skip frame if model is busy (non-blocking for camera) + if (!inference_mutex_.try_lock()) { + return {}; // Return empty map, don't block camera thread + } + std::lock_guard lock(inference_mutex_, std::adopt_lock); + + // Get ArrayBuffer from JSI + auto arrayBuffer = 
pixelData.asObject(runtime).getArrayBuffer(runtime); + uint8_t *data = arrayBuffer.data(runtime); + size_t size = arrayBuffer.size(runtime); + + // Create cv::Mat from raw RGB data (no copy, just wraps the data) + cv::Mat frameImage(height, width, CV_8UC3, data); + + // Preprocess frame (resize and color convert if needed) + cv::Mat preprocessed = preprocessFrame(frameImage); + + // Create tensor and run inference + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + image_processing::getTensorFromMatrix(tensorDims, preprocessed); + + auto forwardResult = BaseModel::forward(inputTensor); + + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed. " + "Ensure the model input is correct."); + } + return postprocess(forwardResult->at(0).toTensor()); +} + +cv::Mat Classification::preprocessFrame(const cv::Mat &frame) const { + // Get target size from model input shape + const std::vector tensorDims = getAllInputShapes()[0]; + cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], + tensorDims[tensorDims.size() - 2]); + + // Resize and convert color + cv::Mat processed; + cv::resize(frame, processed, tensorSize); + cv::cvtColor(processed, processed, cv::COLOR_BGR2RGB); + + return processed; +} + std::unordered_map Classification::postprocess(const Tensor &tensor) { std::span resultData( diff --git a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h index 1465fc5f9..db46c6417 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h @@ -6,20 +6,27 @@ #include #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include namespace rnexecutorch { namespace models::classification { using executorch::aten::Tensor; using executorch::extension::TensorPtr; -class Classification : public BaseModel { +class Classification : public VisionModel { public: Classification(const std::string &modelSource, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::unordered_map< std::string_view, float> generate(std::string imageSource); + [[nodiscard("Registered non-void function")]] std::unordered_map< + std::string_view, float> + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &pixelData, + int width, int height); + +protected: + cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: std::unordered_map postprocess(const Tensor &tensor); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp index 3bb1f9dea..792d03303 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp @@ -9,7 +9,7 @@ namespace rnexecutorch::models::object_detection { ObjectDetection::ObjectDetection( const std::string &modelSource, std::shared_ptr callInvoker) - : BaseModel(modelSource, callInvoker) { + : VisionModel(modelSource, callInvoker) { auto inputTensors = getAllInputShapes(); if (inputTensors.size() == 0) { throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs, @@ 
-71,6 +71,8 @@ ObjectDetection::postprocess(const std::vector &tensors, std::vector ObjectDetection::generate(std::string imageSource, double detectionThreshold) { + std::lock_guard lock(inference_mutex_); + auto [inputTensor, originalSize] = image_processing::readImageToTensor(imageSource, getAllInputShapes()[0]); @@ -83,4 +85,62 @@ ObjectDetection::generate(std::string imageSource, double detectionThreshold) { return postprocess(forwardResult.get(), originalSize, detectionThreshold); } + +std::vector +ObjectDetection::generateFromFrame(jsi::Runtime &runtime, + const jsi::Value &pixelData, + double detectionThreshold) { + // Try-lock: skip frame if model is busy (non-blocking for camera) + if (!inference_mutex_.try_lock()) { + return {}; // Return empty vector, don't block camera thread + } + std::lock_guard lock(inference_mutex_, std::adopt_lock); + + // Get ArrayBuffer from JSI + auto frameObj = pixelData.asObject(runtime); + auto frameData = frameObj.getProperty(runtime, "data"); + int width = + static_cast(frameObj.getProperty(runtime, "width").asNumber()); + int height = + static_cast(frameObj.getProperty(runtime, "height").asNumber()); + + auto arrayBuffer = frameData.asObject(runtime).getArrayBuffer(runtime); + uint8_t *data = arrayBuffer.data(runtime); + + // Create cv::Mat from raw RGB data (no copy, just wraps the data) + cv::Mat frameImage(height, width, CV_8UC3, data); + cv::Size originalSize(width, height); + + // Preprocess frame (resize and color convert) + cv::Mat preprocessed = preprocessFrame(frameImage); + + // Create tensor and run inference + const std::vector tensorDims = getAllInputShapes()[0]; + auto inputTensor = + image_processing::getTensorFromMatrix(tensorDims, preprocessed); + + auto forwardResult = BaseModel::forward(inputTensor); + + if (!forwardResult.ok()) { + throw RnExecutorchError(forwardResult.error(), + "The model's forward function did not succeed. 
" + "Ensure the model input is correct."); + } + + return postprocess(forwardResult.get(), originalSize, detectionThreshold); +} + +cv::Mat ObjectDetection::preprocessFrame(const cv::Mat &frame) const { + // Get target size from model input shape + const std::vector tensorDims = getAllInputShapes()[0]; + cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], + tensorDims[tensorDims.size() - 2]); + + // Resize and convert color + cv::Mat processed; + cv::resize(frame, processed, tensorSize); + cv::cvtColor(processed, processed, cv::COLOR_BGR2RGB); + + return processed; +} } // namespace rnexecutorch::models::object_detection diff --git a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h index bba09a6d8..d2d328035 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.h @@ -8,7 +8,7 @@ #include "Types.h" #include "rnexecutorch/metaprogramming/ConstructorHelpers.h" -#include +#include #include namespace rnexecutorch { @@ -16,12 +16,18 @@ namespace models::object_detection { using executorch::extension::TensorPtr; using executorch::runtime::EValue; -class ObjectDetection : public BaseModel { +class ObjectDetection : public VisionModel { public: ObjectDetection(const std::string &modelSource, std::shared_ptr callInvoker); [[nodiscard("Registered non-void function")]] std::vector generate(std::string imageSource, double detectionThreshold); + [[nodiscard("Registered non-void function")]] std::vector + generateFromFrame(jsi::Runtime &runtime, const jsi::Value &pixelData, + double detectionThreshold); + +protected: + cv::Mat preprocessFrame(const cv::Mat &frame) const override; private: std::vector postprocess(const std::vector &tensors, diff --git a/packages/react-native-executorch/react-native-executorch.podspec b/packages/react-native-executorch/react-native-executorch.podspec index e7c72b92c..89a5645c5 100644 --- a/packages/react-native-executorch/react-native-executorch.podspec +++ b/packages/react-native-executorch/react-native-executorch.podspec @@ -70,6 +70,7 @@ Pod::Spec.new do |s| # #include "Header.h" we get a conflict. Here, headers in jsi/ collide with # react-native-skia. The headers are preserved by preserve_paths and # then made available by HEADER_SEARCH_PATHS. 
+ s.exclude_files = [ "common/rnexecutorch/tests/*.{cpp}", "common/rnexecutorch/jsi/*.{h,hpp}" diff --git a/yarn.lock b/yarn.lock index 1ca8d5d29..fb5e8a1c7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4961,6 +4961,18 @@ __metadata: languageName: node linkType: hard +"ajv@npm:^8.11.0": + version: 8.17.1 + resolution: "ajv@npm:8.17.1" + dependencies: + fast-deep-equal: "npm:^3.1.3" + fast-uri: "npm:^3.0.1" + json-schema-traverse: "npm:^1.0.0" + require-from-string: "npm:^2.0.2" + checksum: 10/ee3c62162c953e91986c838f004132b6a253d700f1e51253b99791e2dbfdb39161bc950ebdc2f156f8568035bb5ed8be7bd78289cd9ecbf3381fe8f5b82e3f33 + languageName: node + linkType: hard + "anser@npm:^1.4.9": version: 1.4.10 resolution: "anser@npm:1.4.10" @@ -6208,6 +6220,7 @@ __metadata: "@types/pngjs": "npm:^6.0.5" "@types/react": "npm:~19.1.10" expo: "npm:^54.0.27" + expo-build-properties: "npm:~1.0.10" expo-constants: "npm:~18.0.11" expo-font: "npm:~14.0.10" expo-linking: "npm:~8.0.10" @@ -6226,7 +6239,10 @@ __metadata: react-native-screens: "npm:~4.16.0" react-native-svg: "npm:15.12.1" react-native-svg-transformer: "npm:^1.5.0" + react-native-vision-camera: "npm:^4.7.3" react-native-worklets: "npm:0.5.1" + react-native-worklets-core: "npm:^1.3.3" + vision-camera-resize-plugin: "npm:^3.2.0" languageName: unknown linkType: soft @@ -7654,6 +7670,18 @@ __metadata: languageName: node linkType: hard +"expo-build-properties@npm:~1.0.10": + version: 1.0.10 + resolution: "expo-build-properties@npm:1.0.10" + dependencies: + ajv: "npm:^8.11.0" + semver: "npm:^7.6.0" + peerDependencies: + expo: "*" + checksum: 10/0dde41d659d243268ceae49bba3e4c07b72c245df8124f86fb720bc0556a2c4d03dd75e59e068a07438ef5ba3188b67a7a6516d2a37d3d91429070745b2506a2 + languageName: node + linkType: hard + "expo-calendar@npm:~15.0.8": version: 15.0.8 resolution: "expo-calendar@npm:15.0.8" @@ -7932,6 +7960,13 @@ __metadata: languageName: node linkType: hard +"fast-uri@npm:^3.0.1": + version: 3.1.0 + resolution: "fast-uri@npm:3.1.0" + checksum: 10/818b2c96dc913bcf8511d844c3d2420e2c70b325c0653633f51821e4e29013c2015387944435cd0ef5322c36c9beecc31e44f71b257aeb8e0b333c1d62bb17c2 + languageName: node + linkType: hard + "fast-xml-parser@npm:^4.4.1": version: 4.5.3 resolution: "fast-xml-parser@npm:4.5.3" @@ -9971,6 +10006,13 @@ __metadata: languageName: node linkType: hard +"json-schema-traverse@npm:^1.0.0": + version: 1.0.0 + resolution: "json-schema-traverse@npm:1.0.0" + checksum: 10/02f2f466cdb0362558b2f1fd5e15cce82ef55d60cd7f8fa828cf35ba74330f8d767fcae5c5c2adb7851fa811766c694b9405810879bc4e1ddd78a7c0e03658ad + languageName: node + linkType: hard + "json-stable-stringify-without-jsonify@npm:^1.0.1": version: 1.0.1 resolution: "json-stable-stringify-without-jsonify@npm:1.0.1" @@ -13119,6 +13161,38 @@ __metadata: languageName: node linkType: hard +"react-native-vision-camera@npm:^4.7.3": + version: 4.7.3 + resolution: "react-native-vision-camera@npm:4.7.3" + peerDependencies: + "@shopify/react-native-skia": "*" + react: "*" + react-native: "*" + react-native-reanimated: "*" + react-native-worklets-core: "*" + peerDependenciesMeta: + "@shopify/react-native-skia": + optional: true + react-native-reanimated: + optional: true + react-native-worklets-core: + optional: true + checksum: 10/2487d3651cb07918820e1f255480e28fec77f936450570b531b429a1861fad573f6bfd4cdb15a78ea7adb471ee45d08e0006cdac81db4fe35949291f3e12d680 + languageName: node + linkType: hard + +"react-native-worklets-core@npm:^1.3.3": + version: 1.6.2 + resolution: "react-native-worklets-core@npm:1.6.2" + 
dependencies: + string-hash-64: "npm:^1.0.3" + peerDependencies: + react: "*" + react-native: "*" + checksum: 10/beeb767ac1fe8229d1bd31890253320b0c05df61a9a2498f274f2752429eb2a8c23491a355100c344ea1d988424b6ee61b723c264c134240039fec38a9d7df38 + languageName: node + linkType: hard + "react-native-worklets@npm:0.5.1": version: 0.5.1 resolution: "react-native-worklets@npm:0.5.1" @@ -14193,6 +14267,13 @@ __metadata: languageName: node linkType: hard +"string-hash-64@npm:^1.0.3": + version: 1.0.3 + resolution: "string-hash-64@npm:1.0.3" + checksum: 10/39aab30e05dfe2effe9a0807a4987cedb6c04f4dfcda3c4915add5b90d68676a6c583d2ad51683932ac3546a4f64a5e99136521b8f2eb1956d581bd2947a96a9 + languageName: node + linkType: hard + "string-length@npm:^4.0.1": version: 4.0.2 resolution: "string-length@npm:4.0.2" @@ -15060,6 +15141,18 @@ __metadata: languageName: node linkType: hard +"vision-camera-resize-plugin@npm:^3.2.0": + version: 3.2.0 + resolution: "vision-camera-resize-plugin@npm:3.2.0" + peerDependencies: + react: "*" + react-native: "*" + react-native-vision-camera: ">=4.0.1" + react-native-worklets-core: ">=1.2.0" + checksum: 10/39f06e35e0fb92e77815d97d7967aec84c4875c99128d7178f2784a9b0f2967e76fc7bb608768d18ee618a64ae899e0318d688c3c9f8556534876b295f008244 + languageName: node + linkType: hard + "vlq@npm:^1.0.0": version: 1.0.1 resolution: "vlq@npm:1.0.1"
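For reviewers: a minimal, self-contained sketch of the FunctionTraits / TailSignature mechanism that visionHostFunction relies on. The first C++ parameter (the pixel-data jsi::Value) is handled manually, while a dummy free function carrying only the remaining parameter types drives generic argument conversion. Detector, PixelData, and argsFromSomewhere are invented stand-ins for the real model class and JSI plumbing; they are not part of this change.

#include <cstddef>
#include <iostream>
#include <tuple>

template <typename T> struct FunctionTraits;
template <typename R, typename C, typename... Args>
struct FunctionTraits<R (C::*)(Args...)> {
  static constexpr std::size_t arity = sizeof...(Args);
};

template <typename T> struct TailSignature;
template <typename R, typename C, typename First, typename... Rest>
struct TailSignature<R (C::*)(First, Rest...)> {
  static void dummy(Rest...) {} // carries the tail parameter types only
};

// Stand-in for createArgsTupleFromJsi(R (*)(Types...), ...): the free-function
// pointer exists only so Types... can be deduced; the values are faked here.
template <typename R, typename... Types>
std::tuple<Types...> argsFromSomewhere(R (*)(Types...)) {
  return std::tuple<Types...>{0.5, 3}; // pretend these came from JS args[1..]
}

struct PixelData {}; // stands in for the jsi::Value handled manually (args[0])

struct Detector {
  int generateFromFrame(PixelData, double threshold, int topK) {
    return static_cast<int>(threshold * 100) + topK;
  }
};

int main() {
  using Fn = decltype(&Detector::generateFromFrame);
  static_assert(FunctionTraits<Fn>::arity == 3); // JS passes arity - 1 args

  Detector d;
  auto tail = argsFromSomewhere(&TailSignature<Fn>::dummy); // std::tuple<double, int>
  int out = std::apply(
      [&](auto &&...rest) { return d.generateFromFrame(PixelData{}, rest...); },
      tail);
  std::cout << out << "\n"; // prints 53
}

The real implementation plays the same trick with the new free-function overload of createArgsTupleFromJsi, so the existing JSI conversion helpers can be reused unchanged for the tail arguments.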