add point model #1381
Conversation
Codecov Report

@@ Coverage Diff @@
##             main    #1381      +/-   ##
==========================================
+ Coverage   86.92%   87.61%   +0.69%
==========================================
  Files          24       26       +2
  Lines        3204     3537     +333
==========================================
+ Hits         2785     3099     +314
- Misses        419      438      +19
Adding a doc page now, but this is good to review. The coverage gaps are training-related; I'll add a sanity test case that'll catch those for now. Hmm, maybe we want to call the checkpoint
@ethanwhite I haven't tapped you for a review, but probably good to have everyone's eyes on this.
Waiting on the metrics to review.
Metrics added, I really thought we'd merged that already, but I think just the dataset... @bw4sz ready for you, and then I'll squash. To avoid confusion, I've renamed everything to

Prediction results tested with:

And Treeformer's KCL paper results I've tested as (I didn't push the checkpoint, but might as well):
Once #1384 is in, I'll refactor the test here.
bw4sz left a comment
Looks great. We should probably call this 'point' and not 'keypoint' anywhere, since keypoint makes me think of DeepLabCut-style things and I think our target user will be happier with point. Next step is to see what the score is on the MillionTrees eval. Great work!
Point prediction model
This PR adds a model that can be used for keypoint/point detection. The (segmentation) model is trained to output a 1/4-size density map with sharp peaks corresponding to the annotated training points (i.e. trees). A local-maximum detector is run over the map and these peaks are returned to the user, with scores integrated over a small radius around each peak. There is no technical limit on the number of predictions per forward pass, and the model seems to perform well in tiled scenarios.
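To make the density-map-to-points step concrete: the PR uses scikit-image's `peak_local_max` for this, but the idea can be sketched in plain numpy. Everything below (function name, parameter names, defaults) is a hypothetical illustration, not the PR's implementation:

```python
import numpy as np

def extract_peaks(density, min_distance=3, threshold=0.1, score_radius=2):
    """Toy local-maximum peak extraction over a density map.

    A pixel is a peak if it exceeds `threshold` and is the maximum of its
    (2*min_distance+1)^2 neighborhood. Each peak's score is the density
    integrated over a small window of radius `score_radius` around it.
    Exact ties within a window can yield duplicate peaks; a real
    implementation (e.g. skimage's peak_local_max) handles plateaus.
    """
    h, w = density.shape
    peaks, scores = [], []
    for y in range(h):
        for x in range(w):
            v = density[y, x]
            if v < threshold:
                continue
            # Neighborhood bounds, clipped to the image
            y0, y1 = max(0, y - min_distance), min(h, y + min_distance + 1)
            x0, x1 = max(0, x - min_distance), min(w, x + min_distance + 1)
            if v < density[y0:y1, x0:x1].max():
                continue  # a stronger peak is nearby
            # Score: integrate density over a small radius around the peak
            sy0, sy1 = max(0, y - score_radius), min(h, y + score_radius + 1)
            sx0, sx1 = max(0, x - score_radius), min(w, x + score_radius + 1)
            peaks.append((y, x))
            scores.append(float(density[sy0:sy1, sx0:sx1].sum()))
    return peaks, scores
```

The quadratic loop is for clarity only; `peak_local_max` does the equivalent with a maximum filter.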
The architecture is based on Treeformer, using the supervised branch, which in turn is a modification of DM-Count and some other past work. I've trained checkpoints on the authors' KCL dataset which replicate their results, and a pretrained checkpoint on NEON LiDAR which seems to transfer quite well to other unseen datasets, like some of our test imagery in the Everglades.
The default checkpoint is here: https://huggingface.co/weecology/deepforest-keypoint/tree/main and can be loaded with `deepforest(config='keypoint')`.

This branch also includes tested inference code (I've checked tiling) and refactors some of the prediction post-processing to handle the different geometries a little more cleanly. The upshot is that `main` is largely unchanged.

I'm not a huge fan of adding scikit-image as a dependency, but its peak detection algorithm (`peak_local_max`) is very good and I've not found a suitable alternative. There is a discussion on GH about this, but no conclusion.

This branch does not have training support (forward doesn't return a loss); that is in a separate branch and will be split out separately for ease of review.
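On the tiling point: running the model per tile means per-tile peaks have to be shifted into global coordinates and near-duplicates in tile overlaps removed. This is a hypothetical sketch of that merge step (not the PR's code; the helper name and the greedy distance-based dedup are assumptions):

```python
import numpy as np

def merge_tile_peaks(tile_results, min_separation=4.0):
    """Merge peak detections from overlapping tiles.

    tile_results: list of ((off_y, off_x), peaks, scores), where peaks are
    (y, x) in tile-local coordinates. Peaks are shifted by their tile
    offset into global coordinates, then greedily deduplicated: iterate in
    descending score order and keep a peak only if it is at least
    `min_separation` pixels from every peak already kept.
    """
    all_pts, all_scores = [], []
    for (oy, ox), peaks, scores in tile_results:
        for (y, x), s in zip(peaks, scores):
            all_pts.append((y + oy, x + ox))
            all_scores.append(s)
    order = np.argsort(all_scores)[::-1]  # highest score first
    kept = []
    for i in order:
        p = np.array(all_pts[i], dtype=float)
        if all(np.linalg.norm(p - np.array(q)) >= min_separation
               for q, _ in kept):
            kept.append((all_pts[i], all_scores[i]))
    return kept
```

This is effectively distance-based NMS for points, matching the "NMS (peak separation) threshold" mentioned below.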
This example is from Hidden, resampled to 5 cm/px, but probably needs a tweak to the NMS (peak separation) threshold.
Replication
Sample predictions from a model trained and tested on the paper's KCL dataset (orange open circles are ground truth; annotated Google Earth imagery at 0.2 m/px):
Note: the model is probably overfit to this dataset. The paper reports 500 epochs, but the dataset is only 400 images. Nevertheless, the implementation here beats the paper benchmark in supervised mode and also the unsupervised benchmark (likely due to backbone pretraining) on MAE (paper: 16.7/18.5; ours: 18.24 using the density sum, <15 with peak extraction). We get a peak F1 of around 0.68. Convergence is seen after around 200 epochs, but we train for 500 to match the original hyperparameters.
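A "peak F1" for point detection is typically computed by matching predicted points to ground-truth points within some radius and treating matches as true positives. The sketch below uses a simple greedy one-to-one matching; the PR's exact metric may differ, and the function name and radius are assumptions:

```python
import numpy as np

def peak_f1(pred, gt, match_radius=4.0):
    """F1 for point detections via greedy one-to-one matching.

    Each prediction is matched to the nearest still-unmatched ground-truth
    point within `match_radius`; matched pairs are true positives.
    Precision = TP / #pred, recall = TP / #gt.
    """
    pred = [np.asarray(p, dtype=float) for p in pred]
    gt = [np.asarray(g, dtype=float) for g in gt]
    matched = [False] * len(gt)
    tp = 0
    for p in pred:
        best, best_d = None, match_radius
        for j, g in enumerate(gt):
            d = np.linalg.norm(p - g)
            if not matched[j] and d <= best_d:
                best, best_d = j, d
        if best is not None:
            matched[best] = True
            tp += 1
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gt) if gt else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)
```

Greedy matching is order-dependent; a stricter evaluation would use optimal (Hungarian) assignment, but the greedy version is close for well-separated trees.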
Predictions after a single epoch on NEON LiDAR (purple are predictions). The model used for the prediction above was trained for 10 epochs and will be updated when we have a final version, as it still had training headroom.
Model differences
Licensing
TreeFormer does not list a license, and the ancestral code from DM-Count (which is mostly copied verbatim in TreeFormer) is MIT-licensed. Here, the backbone model is taken directly from `transformers` and the various heads are re-implemented in PyTorch with some optimization over the original.

Related Issue(s)
#809
AI-Assisted Development
AI tools used (if applicable):
Claude (Opus/Sonnet 4.6) and Codex 5.3/5.4