This project focuses on detecting lick events in videos of mice using a temporal deep learning model. The pipeline includes video processing, model training, prediction, and evaluation.
- `config.py`: Contains all configuration settings for the project, including paths, model parameters, and processing settings.
- `model_architecture.py`: Defines the neural network architecture for the temporal lick classifier.
- `train_temporal_lick_classifier.py`: Script for training the temporal lick classification model.
- `predict_temporal_lick_events.py`: Script to predict lick events in video frames using a trained temporal model. It can process pre-extracted frames or extract frames from videos on the fly.
- `video.py`: Utility functions for video processing, including frame extraction, prediction visualization, and loading trained models.
- `split.py`: Script for splitting the dataset into training and testing sets and extracting frames from videos.
- `evaluate_all_animals.py`: Script for evaluating model performance on pre-extracted frames for all animals in the dataset.
- `evaluate_temporal_model.py`: Script for evaluating the temporal model; likely a more general version or an older iteration of the evaluation scripts.
- `evaluate_video_model.py`: A newer evaluation script that works directly with raw video files: it extracts frames, runs predictions, and generates detailed reports and visualizations.
- `model_architecture.png`: A visual representation of the model architecture.
- `models/`: Stores trained model checkpoints.
- `labeled_videos/`: Output directory for videos with visualized predictions.
- `evaluation_results/`: Output directory for evaluation summaries, plots, and videos generated by `evaluate_video_model.py`.
- `runs/`: Used by TensorBoard to store training logs.
- `temp_frames/`: Temporary directory used by `evaluate_video_model.py` to store extracted frames during processing.
- Configuration (`config.py`): All project-wide settings are centralized here, including data paths, model hyperparameters (sequence length, LSTM hidden size, etc.), and evaluation thresholds.
- Data Preparation (`split.py`):
  - Loads a main dataset (e.g., `swallow_lick_breath_tracking_dataset.pkl`) containing information about animal IDs, session dates, trials, and lick event labels.
  - Extracts frames from raw video files (e.g., `.mp4` from `I:/side/ANIMAL_ID/SESSION_DATE/`).
  - Splits the data into training and testing sets and saves label files (e.g., `test_labels_lick_ANIMAL_ID.pkl`).
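The split step above can be sketched as follows. This is a minimal illustration, not the actual `split.py`: the record fields (`animal_id`, `trial`, `lick`) and function names are assumptions, and real records would also carry session dates and frame paths.

```python
import pickle
import random
from collections import defaultdict

def split_labels(records, test_frac=0.2, seed=0):
    """Group trial records by animal, then split each animal's trials
    into train/test sets. Record fields are illustrative."""
    by_animal = defaultdict(list)
    for rec in records:
        by_animal[rec["animal_id"]].append(rec)
    rng = random.Random(seed)
    splits = {}
    for animal, trials in by_animal.items():
        rng.shuffle(trials)
        n_test = max(1, int(len(trials) * test_frac))
        splits[animal] = {"test": trials[:n_test], "train": trials[n_test:]}
    return splits

def save_test_labels(splits, out_dir="."):
    """Write one test_labels_lick_<ANIMAL_ID>.pkl file per animal,
    matching the naming scheme described above."""
    for animal, split in splits.items():
        path = f"{out_dir}/test_labels_lick_{animal}.pkl"
        with open(path, "wb") as f:
            pickle.dump(split["test"], f)
```

Splitting per animal (rather than over the pooled dataset) keeps every animal represented in both train and test sets.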
- Model Architecture (`model_architecture.py`): Defines a temporal model, likely using a CNN backbone (e.g., ResNet18) to extract features from individual frames, followed by an LSTM or similar recurrent layer to capture temporal dependencies across a sequence of frames.
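A CNN-plus-LSTM design of this kind can be sketched as below. This is an assumption-laden stand-in, not the project's `model_architecture.py`: a small conv stack replaces the ResNet18 backbone, and the feature and hidden sizes are illustrative.

```python
import torch
import torch.nn as nn

class TemporalLickClassifier(nn.Module):
    """Per-frame CNN features + LSTM over the sequence.
    The conv stack stands in for a pretrained backbone such as
    ResNet18; dimensions here are illustrative only."""
    def __init__(self, feat_dim=64, hidden_size=128):
        super().__init__()
        self.backbone = nn.Sequential(               # per-frame features
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, feat_dim),
        )
        self.lstm = nn.LSTM(feat_dim, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)        # lick logit

    def forward(self, x):                            # x: (B, T, C, H, W)
        b, t = x.shape[:2]
        feats = self.backbone(x.flatten(0, 1)).view(b, t, -1)
        out, _ = self.lstm(feats)
        return self.head(out[:, -1])                 # score the last frame
```

Note that the head reads only the LSTM output at the final time step, matching the convention (described under Prediction below) of assigning each window's prediction to its last frame.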
- Model Training (`train_temporal_lick_classifier.py`):
  - Loads frame sequences and their corresponding labels.
  - Trains the temporal model defined in `model_architecture.py`.
  - Saves the best-performing model checkpoint to the `models/` directory.
  - Uses TensorBoard for logging training progress.
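The training loop might look roughly like this. It is a sketch, not the actual training script: the loss choice, optimizer, and best-checkpoint criterion (lowest epoch loss) are assumptions, and TensorBoard logging is elided.

```python
import torch
import torch.nn as nn

def train(model, loader, epochs=2, ckpt_path="best_model.pt"):
    """Minimal training loop: binary cross-entropy with logits over
    sequence batches, saving the checkpoint with the lowest epoch
    loss. The real script also logs progress to TensorBoard."""
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()
    best = float("inf")
    for epoch in range(epochs):
        total = 0.0
        for seqs, labels in loader:              # seqs: (B, T, C, H, W)
            opt.zero_grad()
            logits = model(seqs).squeeze(1)
            loss = loss_fn(logits, labels.float())
            loss.backward()
            opt.step()
            total += loss.item()
        if total < best:                          # keep the best checkpoint
            best = total
            torch.save(model.state_dict(), ckpt_path)
    return best
```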
- Prediction (`predict_temporal_lick_events.py`, `video.py`):
  - Loads a trained model.
  - Accepts either a path to a folder of pre-extracted frames or a path to a video file; if a video file is provided, it extracts frames first.
  - Creates sequences of frames.
  - Performs inference on these sequences to predict lick probabilities for each frame (specifically, the last frame of each sequence window).
  - Outputs predictions, which can be saved or used for visualization.
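The sliding-window inference described above can be sketched as follows. This is illustrative, not the project's prediction code: the function name and the strategy of backfilling the leading frames (which have no full window) with the first window's probability are assumptions.

```python
import torch

def predict_per_frame(model, frames, seq_len=8):
    """Slide a window of seq_len frames over the video; each window's
    prediction is assigned to the window's last frame. Leading frames
    with no full window reuse the first prediction (an assumption)."""
    model.eval()
    probs = []
    with torch.no_grad():
        for end in range(seq_len, len(frames) + 1):
            seq = torch.stack(frames[end - seq_len:end]).unsqueeze(0)
            probs.append(torch.sigmoid(model(seq)).item())
    # backfill the first seq_len - 1 frames, which have no full window
    pad = [probs[0]] * (seq_len - 1) if probs else []
    return pad + probs
```

The result is one probability per frame, aligned so that frame `i`'s score comes from the window ending at frame `i`.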
- Evaluation (`evaluate_video_model.py`, `evaluate_all_animals.py`):
  - `evaluate_video_model.py`:
    - Loads the main dataset to get video paths and labels for all animals.
    - For each animal, samples a number of videos.
    - Extracts frames from these videos into a temporary directory.
    - Runs predictions using the trained model.
    - Calculates metrics (accuracy, precision, recall, F1-score, positive accuracy).
    - Generates plots comparing predictions to ground truth.
    - If accuracy is below a threshold, creates labeled videos showing predictions overlaid on the original video frames.
    - Saves a summary of results to CSV files in `evaluation_results/`.
  - `evaluate_all_animals.py`: An older script that evaluates on pre-split frames, referencing label files generated by `split.py`.
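The metric computation can be sketched with scikit-learn. This is a stand-in, not the evaluation scripts' code: the 0.5 threshold and the definition of "positive accuracy" as accuracy restricted to ground-truth lick frames are assumptions.

```python
import numpy as np
from sklearn.metrics import (accuracy_score, precision_score,
                             recall_score, f1_score)

def lick_metrics(y_true, y_prob, threshold=0.5):
    """Threshold per-frame probabilities and compute the summary
    metrics listed above. 'Positive accuracy' here means accuracy
    on ground-truth lick frames (assumed definition)."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    pos = y_true == 1
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0),
        "positive_accuracy": float((y_pred[pos] == 1).mean())
                             if pos.any() else float("nan"),
    }
```

A separate positive-accuracy figure is useful here because lick frames are typically rare, so overall accuracy alone can look high even when most licks are missed.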
- Configure Paths: Ensure all paths in `config.py` (e.g., `DATA_ROOT`, `BEST_MODEL_PATH`, the dataset pickle file path, the video root path `I:/side`) are correctly set for your environment.
- Prepare Data: Run `split.py` if you need to pre-extract frames and generate label files. This step may be optional if `evaluate_video_model.py` is used primarily, as it handles frame extraction on the fly.
- Train Model: Run `train_temporal_lick_classifier.py` to train the model. Ensure your dataset is prepared and the paths in `config.py` are correct.
- Run Predictions: Use `predict_temporal_lick_events.py --frames <path_to_frames_folder>` or `predict_temporal_lick_events.py --video <path_to_video_file>`.
- Evaluate Model: Run `python evaluate_video_model.py` to evaluate the model on raw videos; this script handles frame extraction, prediction, and reporting. Optionally, run `python evaluate_all_animals.py` if you are working with pre-extracted frames.
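For orientation, a `config.py` of the kind described above might look like this. Only `DATA_ROOT`, `BEST_MODEL_PATH`, the `I:/side` video root, and the dataset pickle path appear in this README; every other name and value below is a hypothetical placeholder.

```python
# Illustrative config.py layout. Only DATA_ROOT, BEST_MODEL_PATH,
# VIDEO_ROOT, and DATASET_PKL are named in the project docs; the
# remaining settings and all concrete values are assumptions.
DATA_ROOT = "C:/Users/Nuo Lab/Desktop/lick_proj"
DATASET_PKL = f"{DATA_ROOT}/swallow_lick_breath_tracking_dataset.pkl"
VIDEO_ROOT = "I:/side"                     # raw videos: VIDEO_ROOT/ANIMAL_ID/SESSION_DATE/
BEST_MODEL_PATH = "models/best_model.pt"   # hypothetical checkpoint name

SEQUENCE_LENGTH = 8        # frames per input window (assumed)
LSTM_HIDDEN_SIZE = 128     # recurrent layer width (assumed)
PRED_THRESHOLD = 0.5       # probability cutoff for a lick (assumed)
```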
- Python 3.x
- PyTorch
- torchvision
- OpenCV (cv2)
- NumPy
- scikit-learn
- Matplotlib
- Pandas
- tqdm
To install dependencies (example):

```
pip install torch torchvision opencv-python numpy scikit-learn matplotlib pandas tqdm
```

- The system is designed to map predictions to the last frame in a sequence window.
- It handles cases where videos might have fewer frames than the required sequence length by padding.
- The evaluation scripts can create visualizations and save detailed metrics.
- The primary video data is expected to be in `I:/side/`.
- The main dataset metadata is in `C:/Users/Nuo Lab/Desktop/lick_proj/swallow_lick_breath_tracking_dataset.pkl`.
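The short-video padding mentioned in the notes could work as sketched below. The exact scheme is an assumption (here, the first frame is repeated at the front until one full window exists); the real code may pad differently.

```python
import torch

def pad_short_video(frames, seq_len):
    """If a video has fewer frames than seq_len, repeat the first
    frame at the front so at least one full sequence window exists.
    The padding scheme is an assumption, not the project's code."""
    if len(frames) >= seq_len:
        return list(frames)
    pad = [frames[0].clone() for _ in range(seq_len - len(frames))]
    return pad + list(frames)
```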
This README provides a general overview. For specific details, refer to the source code and comments within each Python file.