Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
851b01d
Merge pull request #13 from deruyter92/jaap/add_config_and_registry
xiu-cs Feb 10, 2026
799be4d
Refactor FMPose3D test script to use model_type instead of model_path
xiu-cs Feb 10, 2026
7acba1f
Refactor FMPose3D_main.py to load model using get_model from registry…
xiu-cs Feb 10, 2026
f173935
Update FMPose3D_train.sh to use model_type for model selection instea…
xiu-cs Feb 10, 2026
4f07856
Merge branch 'feat/add_api' into ti_video_demo
xiu-cs Feb 10, 2026
f878e7f
Add model_type argument to opts for model registry selection
xiu-cs Feb 10, 2026
321b28a
Remove unnecessary comment block in HRNet implementation file
xiu-cs Feb 10, 2026
ea1d3f7
Import animal models to ensure their registration in the model registry.
xiu-cs Feb 10, 2026
dd5be5d
Register Model class for animal3D in the model registry and update in…
xiu-cs Feb 10, 2026
a7e25e9
Refactor main_animal3d.py to load model from the registered model reg…
xiu-cs Feb 10, 2026
f1edf44
Update test_animal3d.sh to modify eval_sample_steps, change model_typ…
xiu-cs Feb 10, 2026
51b14ab
Update train_animal3d.sh to modify eval_sample_steps, change model_ty…
xiu-cs Feb 10, 2026
4702481
Refactor vis_animals.py to load model using get_model from the regist…
xiu-cs Feb 10, 2026
b02deda
Update vis_animals.sh to set model_type for FMPose3D and adjust saved…
xiu-cs Feb 10, 2026
16f340c
Update README.md with new download link for pre-trained model
xiu-cs Feb 10, 2026
0512c27
Refactor backup file handling in main_animal3d.py to enable file copy…
xiu-cs Feb 11, 2026
a440a08
Update test_animal3d.sh to change test dataset path and adjust script…
xiu-cs Feb 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion animals/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ In this part, the FMPose3D model is trained on [Animal3D](https://xujiacong.gith
This visualization script is designed for single-frame based model, allowing you to easily run 3D animal pose estimation on any single image.

Before testing, make sure you have the pre-trained model ready.
You may either use the model trained by your own or download ours from [here](https://drive.google.com/drive/folders/1fMKVaYziwFkAnFrtQZmoPOTfe7Hkl2at?usp=sharing) and place it in the `./pre_trained_models` directory.
You may either use the model trained by your own or download ours from [here](https://drive.google.com/drive/folders/1kL4aOyWNq0o9zB0rSTRM8KYgkySVmUTk?usp=drive_link) and place it in the `./pre_trained_models` directory.

Next, put your test images into folder `demo/images`. Then run the visualization script:
```bash
Expand Down
5 changes: 3 additions & 2 deletions animals/demo/vis_animals.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
spec.loader.exec_module(module)
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
from fmpose3d.models import Model as CFM
# Load model from registered model registry
from fmpose3d.models import get_model
CFM = get_model(args.model_type)

from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images

Expand Down
8 changes: 5 additions & 3 deletions animals/demo/vis_animals.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ sh_file='vis_animals.sh'
# n_joints=26
# out_joints=26

model_path='../pre_trained_models/animal3d_pretrained_weights/model_animal3d.py'
saved_model_path='../pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'
model_type='fmpose3d_animals'
# model_path='' # set to a local file path to override the registry
saved_model_path='../pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth'

# path='./images/image_00068.jpg' # single image
input_images_folder='./images/' # folder containing multiple images
Expand All @@ -17,7 +18,8 @@ python3 vis_animals.py \
--type 'image' \
--path ${input_images_folder} \
--saved_model_path "${saved_model_path}" \
--model_path "${model_path}" \
${model_path:+--model_path "$model_path"} \
--model_type "${model_type}" \
--sample_steps ${sample_steps} \
--batch_size ${batch_size} \
--layers ${layers} \
Expand Down
33 changes: 15 additions & 18 deletions animals/scripts/main_animal3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
spec.loader.exec_module(module)
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
from fmpose3d.animals.models import Model as CFM
# Load model from registered model registry
from fmpose3d.models import get_model
CFM = get_model(args.model_type)

def train(opt, actions, train_loader, model, optimizer, epoch):
return step('train', opt, actions, train_loader, model, optimizer, epoch)
Expand Down Expand Up @@ -98,7 +99,6 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st
gt_3D = gt_3D.clone()
gt_3D[:, :, args.root_joint] = 0


# Conditional Flow Matching training
# gt_3D, input_2D shape: (B,F,J,C)
# vis_3D shape: (B,F,J,1) - visibility mask
Expand Down Expand Up @@ -217,21 +217,18 @@ def get_parameter_number(net):
os.makedirs(args.checkpoint)

# backup files
# import shutil
# file_path = os.path.abspath(__file__)
# file_name = os.path.basename(file_path)
# shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
# shutil.copyfile(src=os.path.abspath("common/arguments.py"), dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py"))
# # backup the selected model file (from --model_path if provided)
# if getattr(args, 'model_path', ''):
# model_src_path = os.path.abspath(args.model_path)
# model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
# shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
# # shutil.copyfile(src="common/utils.py", dst = os.path.join(args.checkpoint, args.create_time + "_utils.py"))
# sh_base = os.path.basename(args.sh_file)
# dst_name = f"{args.create_time}_" + sh_base
# sh_src = os.path.abspath(args.sh_file)
# shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))
import shutil
file_path = os.path.abspath(__file__)
file_name = os.path.basename(file_path)
shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name))
if getattr(args, 'model_path', ''):
model_src_path = os.path.abspath(args.model_path)
model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name))
sh_base = os.path.basename(args.sh_file)
dst_name = f"{args.create_time}_" + sh_base
sh_src = os.path.abspath(args.sh_file)
shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name))

logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', \
filename=os.path.join(args.checkpoint, 'train.log'), level=logging.INFO)
Expand Down
12 changes: 7 additions & 5 deletions animals/scripts/test_animal3d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@ layers=5
batch_size=13
lr=1e-3
gpu_id=0
eval_sample_steps=3
eval_sample_steps=5
num_saved_models=3
frames=1
large_decay_epoch=15
lr_decay_large=0.75
n_joints=26
out_joints=26
epochs=300
# model_path='models/model_animals.py'
model_path='./pre_trained_models/animal3d_pretrained_weights/model_animal3d.py' # when the path is empty, the model will be loaded from the installed fmpose package
saved_model_path='./pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth'
model_type='fmpose3d_animals'
# model_path='' # set to a local file path to override the registry
saved_model_path='./pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth'

# root path denotes the path to the original dataset
root_path="./dataset/"
train_dataset_paths=(
Expand All @@ -24,7 +25,7 @@ test_dataset_paths=(
)

folder_name="TestCtrlAni3D_L${layers}_lr${lr}_B${batch_size}_$(date +%Y%m%d_%H%M%S)"
sh_file='scripts/animals/test_animal3d.sh'
sh_file='scripts/test_animal3d.sh'

python ./scripts/main_animal3d.py \
--root_path ${root_path} \
Expand All @@ -33,6 +34,7 @@ python ./scripts/main_animal3d.py \
--test 1 \
--batch_size ${batch_size} \
--lr ${lr} \
--model_type "${model_type}" \
${model_path:+--model_path "$model_path"} \
--folder_name ${folder_name} \
--layers ${layers} \
Expand Down
10 changes: 4 additions & 6 deletions animals/scripts/train_animal3d.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@ layers=5
batch_size=13
lr=1e-3
gpu_id=0
eval_sample_steps=3
eval_sample_steps=5
num_saved_models=3
frames=1
large_decay_epoch=15
lr_decay_large=0.75
n_joints=26
out_joints=26
epochs=300
# model_path='models/model_animals.py'
model_path="" # when the path is empty, the model will be loaded from the installed fmpose package
model_type='fmpose3d_animals'
# model_path="" # set to a local file path to override the registry
# root path denotes the path to the original dataset
root_path="./dataset/"
train_dataset_paths=(
Expand All @@ -32,7 +30,7 @@ python ./scripts/main_animal3d.py \
--test 1 \
--batch_size ${batch_size} \
--lr ${lr} \
${model_path:+--model_path "$model_path"} \
--model_type "${model_type}" \
--folder_name ${folder_name} \
--layers ${layers} \
--gpu ${gpu_id} \
Expand Down
2 changes: 2 additions & 0 deletions fmpose3d/animals/common/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def init(self):
self.parser.add_argument("--model_dir", type=str, default="")
# Optional: load model class from a specific file path
self.parser.add_argument("--model_path", type=str, default="")
# Model registry name (e.g. "fmpose3d_animals"); used instead of --model_path
self.parser.add_argument("--model_type", type=str, default="fmpose3d_animals")

self.parser.add_argument("--post_refine_reload", action="store_true")
self.parser.add_argument("--checkpoint", type=str, default="")
Expand Down
6 changes: 4 additions & 2 deletions fmpose3d/animals/models/model_animal3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from timm.models.layers import DropPath

from fmpose3d.animals.models.graph_frames import Graph
from fmpose3d.models.base_model import BaseModel, register_model

class TimeEmbedding(nn.Module):
def __init__(self, dim: int, hidden_dim: int = 64):
Expand Down Expand Up @@ -207,9 +208,10 @@ def forward(self, x):
x = self.fc5(x)
return x

class Model(nn.Module):
@register_model("fmpose3d_animals")
class Model(BaseModel):
def __init__(self, args):
super().__init__()
super().__init__(args)

self.graph = Graph('animal3d', 'spatial', pad=1)
self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32))
Expand Down
2 changes: 0 additions & 2 deletions fmpose3d/lib/hrnet/hrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
"FMPose3D: monocular 3D Pose Estimation via Flow Matching"
by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis
Licensed under Apache 2.0
"""

"""
FMPose3D – clean HRNet 2D pose estimation API.

Provides :class:`HRNetPose2d`, a self-contained wrapper around the
Expand Down
2 changes: 2 additions & 0 deletions fmpose3d/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

# Import model subpackages so their @register_model decorators execute.
from .fmpose3d import Graph, Model
# Import animal models so their @register_model decorators execute.
from fmpose3d.animals import models as _animal_models # noqa: F401

__all__ = [
"BaseModel",
Expand Down
11 changes: 3 additions & 8 deletions scripts/FMPose3D_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
spec.loader.exec_module(module)
CFM = getattr(module, "Model")
else:
# Load model from installed fmpose package
from fmpose3d.models import Model as CFM
# Load model from registered model registry
from fmpose3d.models import get_model
CFM = get_model(args.model_type)


def test_multi_hypothesis(
Expand Down Expand Up @@ -281,12 +282,6 @@ def print_error_action(action_error_sum, is_train):
src=script_path,
dst=os.path.join(args.checkpoint, args.create_time + "_" + script_name),
)
if getattr(args, "model_path", ""):
model_src_path = os.path.abspath(args.model_path)
model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path)
shutil.copyfile(
src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)
)
sh_base = os.path.basename(args.sh_file)
dst_name = f"{args.create_time}_" + sh_base
shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name))
Expand Down
4 changes: 2 additions & 2 deletions scripts/FMPose3D_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mode='exp'
exp_temp=0.005
folder_name=test_s${eval_multi_steps}_${mode}_h${num_hypothesis_list}_$(date +%Y%m%d_%H%M%S)

model_path='./pre_trained_models/fmpose3d_h36m/model_GAMLP.py'
model_type='fmpose3d'
model_weights_path='./pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth'

#Test
Expand All @@ -20,7 +20,7 @@ python3 scripts/FMPose3D_main.py \
--exp_temp ${exp_temp} \
--folder_name ${folder_name} \
--model_weights_path "${model_weights_path}" \
--model_path "${model_path}" \
--model_type "${model_type}" \
--eval_sample_steps ${eval_multi_steps} \
--test_augmentation True \
--batch_size ${batch_size} \
Expand Down
4 changes: 2 additions & 2 deletions scripts/FMPose3D_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ epochs=80
num_saved_models=3
frames=1
channel_dim=512
model_path="" # when the path is empty, the model will be loaded from the installed fmpose3d package
# model_path='./models/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path
model_type='fmpose3d' # use registered model by default
sh_file='scripts/FMPose3D_train.sh'
folder_name=FMPose3D_layers${layers}_$(date +%Y%m%d_%H%M%S)

python3 scripts/FMPose3D_main.py \
--train \
--dataset h36m \
--frames ${frames} \
--model_type "${model_type}" \
${model_path:+--model_path "$model_path"} \
--gpu ${gpu_id} \
--batch_size ${batch_size} \
Expand Down
Loading