
WIP: Treeformer #1371

Closed

jveitchmichaelis wants to merge 101 commits into weecology:main from jveitchmichaelis:treeformer

WIP: Treeformer#1371
jveitchmichaelis wants to merge 101 commits into
weecology:mainfrom
jveitchmichaelis:treeformer

Conversation

Collaborator

@jveitchmichaelis commented Apr 14, 2026

Description

This is marked as WIP as there's a lot to clean up, but the model architecture and training loop seem to be OK. The changes are mostly additive cruft rather than things that will require extensive rebasing.

Adds a point-detection model based on the TreeFormer architecture (https://arxiv.org/abs/2307.06118), which is in turn based on DM-Count (https://arxiv.org/abs/2009.13077).

This contribution is probably closer in spirit to DM-Count than TreeFormer because it only implements the supervised path, but the code was heavily adapted from the TreeFormer repository.

The model is a segmentation backbone (here, PvTv2) that feeds into two heads: a Global Average Pool (GAP) head that predicts a scalar tree count, and a density regression head (a multi-scale decoder).
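As a rough illustration of that two-headed layout, here is a minimal PyTorch sketch. A tiny conv stack stands in for the PvTv2 backbone, and a couple of conv layers stand in for the multi-scale decoder; all layer names and sizes are illustrative, not the PR's actual values.

```python
import torch
import torch.nn as nn


class TwoHeadDensityModel(nn.Module):
    """Sketch: shared backbone -> (scalar count head, density map head)."""

    def __init__(self, feat_dim: int = 32):
        super().__init__()
        # Placeholder backbone (the PR uses PvTv2 from `transformers`)
        self.backbone = nn.Sequential(
            nn.Conv2d(3, feat_dim, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(feat_dim, feat_dim, 3, stride=2, padding=1), nn.ReLU(),
        )
        # Head 1: Global Average Pool -> scalar tree count
        self.count_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(feat_dim, 1)
        )
        # Head 2: per-pixel density map (stand-in for the multi-scale decoder)
        self.density_head = nn.Sequential(
            nn.Conv2d(feat_dim, feat_dim, 3, padding=1), nn.ReLU(),
            nn.Conv2d(feat_dim, 1, 1), nn.ReLU(),  # densities are non-negative
        )

    def forward(self, x):
        feats = self.backbone(x)
        return self.count_head(feats).squeeze(-1), self.density_head(feats)


model = TwoHeadDensityModel()
count, density = model(torch.randn(2, 3, 256, 256))
# count: one scalar per image; density: one map per image at 1/4 resolution
```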

TODO: add model checkpoints to huggingface and default to them when user picks the treeformer config.
TODO: big clean up of commits, obviously.
TODO: remove slurm scripts, etc.

Replication

Sample predictions trained and tested on the paper's KCL dataset (using Google Earth images):

[image: IMG_158 density plot-6500]

Note: this dataset is probably overfit; the paper reports 500 epochs of training, but the dataset contains only 400 images. Nevertheless, on MAE the implementation here beats the paper's supervised benchmark and also its unsupervised benchmark (likely due to backbone pretraining): paper 16.7/18.5, ours 18.24 using the density sum and below 15 with peak extraction. We get a peak F1 of around 0.68. Convergence is seen after around 200 epochs, but we train for 500 to match the original hyperparameters.
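The "peak extraction" counting mentioned above can be sketched as local-maximum detection on the predicted density map. This is a generic non-maximum-suppression approach, not necessarily the PR's exact method, and the `size` and `threshold` values are illustrative hyperparameters:

```python
import numpy as np
from scipy.ndimage import maximum_filter


def extract_peaks(density: np.ndarray, size: int = 3, threshold: float = 0.5):
    """Return (row, col) positions of local maxima above a threshold."""
    # A pixel is a peak if it equals the max of its neighbourhood
    # and its value clears the detection threshold.
    is_peak = (density == maximum_filter(density, size=size)) & (density > threshold)
    ys, xs = np.nonzero(is_peak)
    return list(zip(ys.tolist(), xs.tolist()))


# Toy density map with two clear "trees"
d = np.zeros((8, 8))
d[2, 2] = 1.0
d[6, 5] = 0.9
peaks = extract_peaks(d)  # -> [(2, 2), (6, 5)]
```

The count is then `len(peaks)`, and peak F1 can be computed by matching predicted peaks to ground-truth points within a distance tolerance.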


A single epoch on NEON LIDAR (purple are preds):

[images: prediction sample-825, prediction sample-825_2]

Model differences

  • TreeFormer is, as far as I can tell, not invariant to input image shape. The vanilla implementation trains on 256-pixel crops and then predicts over 256-pixel tiles; if you change the input shape at test time, the model will under- or over-predict tree density. This implementation changes the output to regress density instead of absolute count, and learns a scale factor that is applied to the input image. This only affects the scale factor, but it's nice to be consistent here.
  • In the original paper the model learns spatial structure and count separately. In this version, we couple the two by forcing the sum of the density map to match the object count for consistency. Typically during training, the model quickly learns how to place tree "points" and then slowly learns to correct the predicted count output.
  • Support for DDP is provided via DeepForest, which involved a few tweaks to various parts of the codebase.
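The count/density coupling described above can be sketched as a combined loss: the GAP head's count is supervised directly, while the density map's total mass is pulled toward the same count. The L1 terms and the `lam` weight are illustrative; the actual loss (following DM-Count) also includes terms not shown here.

```python
import torch
import torch.nn.functional as F


def coupled_loss(pred_count, pred_density, gt_count, lam: float = 1.0):
    """Sketch of coupling the scalar count to the density map's sum."""
    density_sum = pred_density.sum(dim=(1, 2, 3))  # total mass per image
    count_loss = F.l1_loss(pred_count, gt_count)   # supervise the GAP count
    consistency = F.l1_loss(density_sum, gt_count) # tie the map to the count
    return count_loss + lam * consistency


# Toy example: predicted count is right (3 trees), but the density
# map only sums to 2.0, so the consistency term contributes 1.0.
pred_count = torch.tensor([3.0])
pred_density = torch.full((1, 1, 2, 2), 0.5)  # sums to 2.0
loss = coupled_loss(pred_count, pred_density, torch.tensor([3.0]))
# loss.item() -> 1.0
```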

PRs to be split - will update

Licensing

TreeFormer does not list a license, and the ancestral code from DM-Count (which is mostly copied verbatim in TreeFormer) is MIT-licensed. Here, the backbone model is taken directly from transformers, and the various heads are re-implemented in PyTorch with some optimizations over the original.

Related Issue(s)

#809

AI-Assisted Development

  • I used AI tools (e.g., GitHub Copilot, ChatGPT, etc.) in developing this PR
  • I understand all the code I'm submitting
  • I have reviewed and validated all AI-generated code

AI tools used (if applicable):

Claude (Opus/Sonnet 4.6) and Codex 5.3/5.4

@jveitchmichaelis
Collaborator Author

Closing in favour of simpler branch

@jveitchmichaelis jveitchmichaelis mentioned this pull request Apr 22, 2026