diff --git a/.gitignore b/.gitignore index 90439ac..3f7f92b 100644 --- a/.gitignore +++ b/.gitignore @@ -49,4 +49,5 @@ htmlcov/ # Excluded directories pre_trained_models/ demo/predictions/ -demo/images/ \ No newline at end of file +demo/images/ +**/predictions/ \ No newline at end of file diff --git a/animals/README.md b/animals/README.md index af9da69..5121533 100644 --- a/animals/README.md +++ b/animals/README.md @@ -10,7 +10,7 @@ In this part, the FMPose3D model is trained on [Animal3D](https://xujiacong.gith This visualization script is designed for single-frame based model, allowing you to easily run 3D animal pose estimation on any single image. Before testing, make sure you have the pre-trained model ready. -You may either use the model trained by your own or download ours from [here](https://drive.google.com/drive/folders/1fMKVaYziwFkAnFrtQZmoPOTfe7Hkl2at?usp=sharing) and place it in the `./pre_trained_models` directory. +You may either use a model you trained on your own or download ours from [here](https://drive.google.com/drive/folders/1kL4aOyWNq0o9zB0rSTRM8KYgkySVmUTk?usp=drive_link) and place it in the `./pre_trained_models` directory. Next, put your test images into folder `demo/images`. 
Then run the visualization script: ```bash diff --git a/animals/demo/vis_animals.py b/animals/demo/vis_animals.py index c9fe438..459c2bc 100644 --- a/animals/demo/vis_animals.py +++ b/animals/demo/vis_animals.py @@ -42,8 +42,9 @@ spec.loader.exec_module(module) CFM = getattr(module, "Model") else: - # Load model from installed fmpose package - from fmpose3d.models import Model as CFM + # Load model from registered model registry + from fmpose3d.models import get_model + CFM = get_model(args.model_type) from deeplabcut.pose_estimation_pytorch.apis import superanimal_analyze_images diff --git a/animals/demo/vis_animals.sh b/animals/demo/vis_animals.sh index 3879391..e2944c2 100644 --- a/animals/demo/vis_animals.sh +++ b/animals/demo/vis_animals.sh @@ -7,8 +7,9 @@ sh_file='vis_animals.sh' # n_joints=26 # out_joints=26 -model_path='../pre_trained_models/animal3d_pretrained_weights/model_animal3d.py' -saved_model_path='../pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth' +model_type='fmpose3d_animals' +# model_path='' # set to a local file path to override the registry +saved_model_path='../pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth' # path='./images/image_00068.jpg' # single image input_images_folder='./images/' # folder containing multiple images @@ -17,7 +18,8 @@ python3 vis_animals.py \ --type 'image' \ --path ${input_images_folder} \ --saved_model_path "${saved_model_path}" \ - --model_path "${model_path}" \ + ${model_path:+--model_path "$model_path"} \ + --model_type "${model_type}" \ --sample_steps ${sample_steps} \ --batch_size ${batch_size} \ --layers ${layers} \ diff --git a/animals/scripts/main_animal3d.py b/animals/scripts/main_animal3d.py index 0436204..c90bdea 100644 --- a/animals/scripts/main_animal3d.py +++ b/animals/scripts/main_animal3d.py @@ -38,8 +38,9 @@ spec.loader.exec_module(module) CFM = getattr(module, "Model") else: - # Load model from installed fmpose package - from fmpose3d.animals.models 
import Model as CFM + # Load model from registered model registry + from fmpose3d.models import get_model + CFM = get_model(args.model_type) def train(opt, actions, train_loader, model, optimizer, epoch): return step('train', opt, actions, train_loader, model, optimizer, epoch) @@ -98,7 +99,6 @@ def step(split, args, actions, dataLoader, model, optimizer=None, epoch=None, st gt_3D = gt_3D.clone() gt_3D[:, :, args.root_joint] = 0 - # Conditional Flow Matching training # gt_3D, input_2D shape: (B,F,J,C) # vis_3D shape: (B,F,J,1) - visibility mask @@ -217,21 +217,18 @@ def get_parameter_number(net): os.makedirs(args.checkpoint) # backup files - # import shutil - # file_path = os.path.abspath(__file__) - # file_name = os.path.basename(file_path) - # shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name)) - # shutil.copyfile(src=os.path.abspath("common/arguments.py"), dst=os.path.join(args.checkpoint, args.create_time + "_arguments.py")) - # # backup the selected model file (from --model_path if provided) - # if getattr(args, 'model_path', ''): - # model_src_path = os.path.abspath(args.model_path) - # model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path) - # shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)) - # # shutil.copyfile(src="common/utils.py", dst = os.path.join(args.checkpoint, args.create_time + "_utils.py")) - # sh_base = os.path.basename(args.sh_file) - # dst_name = f"{args.create_time}_" + sh_base - # sh_src = os.path.abspath(args.sh_file) - # shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name)) + import shutil + file_path = os.path.abspath(__file__) + file_name = os.path.basename(file_path) + shutil.copyfile(src=file_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + file_name)) + if getattr(args, 'model_path', ''): + model_src_path = os.path.abspath(args.model_path) + model_dst_name = f"{args.create_time}_" + 
os.path.basename(model_src_path) + shutil.copyfile(src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name)) + sh_base = os.path.basename(args.sh_file) + dst_name = f"{args.create_time}_" + sh_base + sh_src = os.path.abspath(args.sh_file) + shutil.copyfile(src=sh_src, dst=os.path.join(args.checkpoint, dst_name)) logging.basicConfig(format='%(asctime)s %(message)s', datefmt='%Y/%m/%d %H:%M:%S', \ filename=os.path.join(args.checkpoint, 'train.log'), level=logging.INFO) diff --git a/animals/scripts/test_animal3d.sh b/animals/scripts/test_animal3d.sh index 3c85885..207e332 100644 --- a/animals/scripts/test_animal3d.sh +++ b/animals/scripts/test_animal3d.sh @@ -2,7 +2,7 @@ layers=5 batch_size=13 lr=1e-3 gpu_id=0 -eval_sample_steps=3 +eval_sample_steps=5 num_saved_models=3 frames=1 large_decay_epoch=15 @@ -10,9 +10,10 @@ lr_decay_large=0.75 n_joints=26 out_joints=26 epochs=300 -# model_path='models/model_animals.py' -model_path='./pre_trained_models/animal3d_pretrained_weights/model_animal3d.py' # when the path is empty, the model will be loaded from the installed fmpose package -saved_model_path='./pre_trained_models/animal3d_pretrained_weights/CFM_154_4403_best.pth' +model_type='fmpose3d_animals' +# model_path='' # set to a local file path to override the registry +saved_model_path='./pre_trained_models/fmpose3d_animals/fmpose3d_animals_pretrained_weights.pth' + # root path denotes the path to the original dataset root_path="./dataset/" train_dataset_paths=( @@ -24,7 +25,7 @@ test_dataset_paths=( ) folder_name="TestCtrlAni3D_L${layers}_lr${lr}_B${batch_size}_$(date +%Y%m%d_%H%M%S)" -sh_file='scripts/animals/test_animal3d.sh' +sh_file='scripts/test_animal3d.sh' python ./scripts/main_animal3d.py \ --root_path ${root_path} \ @@ -33,6 +34,7 @@ python ./scripts/main_animal3d.py \ --test 1 \ --batch_size ${batch_size} \ --lr ${lr} \ + --model_type "${model_type}" \ ${model_path:+--model_path "$model_path"} \ --folder_name ${folder_name} \ --layers ${layers} \ 
diff --git a/animals/scripts/train_animal3d.sh b/animals/scripts/train_animal3d.sh index cdd6f8c..bcba1ee 100644 --- a/animals/scripts/train_animal3d.sh +++ b/animals/scripts/train_animal3d.sh @@ -2,16 +2,14 @@ layers=5 batch_size=13 lr=1e-3 gpu_id=0 -eval_sample_steps=3 +eval_sample_steps=5 num_saved_models=3 frames=1 large_decay_epoch=15 lr_decay_large=0.75 -n_joints=26 -out_joints=26 epochs=300 -# model_path='models/model_animals.py' -model_path="" # when the path is empty, the model will be loaded from the installed fmpose package +model_type='fmpose3d_animals' +# NOTE: unlike test_animal3d.sh, this script no longer forwards --model_path; the model is always loaded from the registry via model_type # root path denotes the path to the original dataset root_path="./dataset/" train_dataset_paths=( @@ -32,7 +30,7 @@ python ./scripts/main_animal3d.py \ --test 1 \ --batch_size ${batch_size} \ --lr ${lr} \ - ${model_path:+--model_path "$model_path"} \ + --model_type "${model_type}" \ --folder_name ${folder_name} \ --layers ${layers} \ --gpu ${gpu_id} \ diff --git a/fmpose3d/__init__.py b/fmpose3d/__init__.py index 563a140..7f23320 100644 --- a/fmpose3d/__init__.py +++ b/fmpose3d/__init__.py @@ -36,6 +36,9 @@ Source, ) +# Model registry +from .models import BaseModel, register_model, get_model, list_models + # Import 2D pose detection utilities from .lib.hrnet.gen_kpts import gen_video_kpts from .lib.hrnet.hrnet import HRNetPose2d @@ -59,6 +62,11 @@ "average_aggregation", "aggregation_select_single_best_hypothesis_by_2D_error", "aggregation_RPEA_joint_level", + # Model registry + "BaseModel", + "register_model", + "get_model", + "list_models", # 2D pose detection "HRNetPose2d", "gen_video_kpts", diff --git a/fmpose3d/animals/common/arguments.py b/fmpose3d/animals/common/arguments.py index 690c540..f465f17 100755 --- a/fmpose3d/animals/common/arguments.py +++ b/fmpose3d/animals/common/arguments.py @@ -68,6 +68,8 @@ def init(self): self.parser.add_argument("--model_dir", type=str, default="") # Optional: load model class from a specific 
file path self.parser.add_argument("--model_path", type=str, default="") + # Model registry name (e.g. "fmpose3d_animals"); used instead of --model_path + self.parser.add_argument("--model_type", type=str, default="fmpose3d_animals") self.parser.add_argument("--post_refine_reload", action="store_true") self.parser.add_argument("--checkpoint", type=str, default="") diff --git a/fmpose3d/animals/models/model_animal3d.py b/fmpose3d/animals/models/model_animal3d.py index 273b1dd..7ea8b95 100644 --- a/fmpose3d/animals/models/model_animal3d.py +++ b/fmpose3d/animals/models/model_animal3d.py @@ -16,6 +16,7 @@ from timm.models.layers import DropPath from fmpose3d.animals.models.graph_frames import Graph +from fmpose3d.models.base_model import BaseModel, register_model class TimeEmbedding(nn.Module): def __init__(self, dim: int, hidden_dim: int = 64): @@ -207,9 +208,10 @@ def forward(self, x): x = self.fc5(x) return x -class Model(nn.Module): +@register_model("fmpose3d_animals") +class Model(BaseModel): def __init__(self, args): - super().__init__() + super().__init__(args) self.graph = Graph('animal3d', 'spatial', pad=1) self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32)) diff --git a/fmpose3d/common/config.py b/fmpose3d/common/config.py index b2980e1..39cbee5 100644 --- a/fmpose3d/common/config.py +++ b/fmpose3d/common/config.py @@ -9,8 +9,7 @@ import math from dataclasses import dataclass, field, fields, asdict -from typing import List - +from typing import Dict, List # --------------------------------------------------------------------------- # Dataclass configuration groups @@ -23,21 +22,60 @@ class ModelConfig: model_type: str = "fmpose3d" +# Per-model-type defaults for fields marked with INFER_FROM_MODEL_TYPE. +# Also consumed by PipelineConfig.for_model_type to set cross-config +# values (dataset, sample_steps, etc.). 
+_FMPOSE3D_DEFAULTS: Dict[str, Dict] = { + "fmpose3d": { + "n_joints": 17, + "out_joints": 17, + "dataset": "h36m", + "sample_steps": 3, + "joints_left": [4, 5, 6, 11, 12, 13], + "joints_right": [1, 2, 3, 14, 15, 16], + "root_joint": 0, + }, + "fmpose3d_animals": { + "n_joints": 26, + "out_joints": 26, + "dataset": "animal3d", + "sample_steps": 3, + "joints_left": [0, 3, 5, 8, 10, 12, 14, 16, 20, 22], + "joints_right": [1, 4, 6, 9, 11, 13, 15, 17, 21, 23], + "root_joint": 7, + }, +} + +# Sentinel object for defaults that are inferred from the model type. +INFER_FROM_MODEL_TYPE = object() + @dataclass class FMPose3DConfig(ModelConfig): - model: str = "" model_type: str = "fmpose3d" - layers: int = 3 + model: str = "" + layers: int = 5 channel: int = 512 d_hid: int = 1024 token_dim: int = 256 - n_joints: int = 17 - out_joints: int = 17 + n_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] + out_joints: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] + joints_left: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment] + joints_right: List[int] = INFER_FROM_MODEL_TYPE # type: ignore[assignment] + root_joint: int = INFER_FROM_MODEL_TYPE # type: ignore[assignment] in_channels: int = 2 out_channels: int = 3 frames: int = 1 - """Optional: load model class from a specific file path.""" + def __post_init__(self): + defaults = _FMPOSE3D_DEFAULTS.get(self.model_type) + if defaults is None: + supported = ", ".join(sorted(_FMPOSE3D_DEFAULTS)) + raise ValueError( + f"Unknown model_type {self.model_type!r}; supported: {supported}" + ) + for f in fields(self): + if getattr(self, f.name) is INFER_FROM_MODEL_TYPE: + setattr(self, f.name, defaults[f.name]) @dataclass class DatasetConfig: @@ -178,6 +216,33 @@ class HRNetConfig(Pose2DConfig): hrnet_weights_path: str = "" +@dataclass +class SuperAnimalConfig(Pose2DConfig): + """DeepLabCut SuperAnimal 2D pose detector configuration. 
+ + Uses the DeepLabCut ``superanimal_analyze_images`` API to detect + animal keypoints in the quadruped80K format, then maps them to the + Animal3D 26-keypoint layout expected by the ``fmpose3d_animals`` + 3D lifter. + + Attributes + ---------- + superanimal_name : str + Name of the SuperAnimal model (default ``"superanimal_quadruped"``). + sa_model_name : str + Backbone architecture (default ``"hrnet_w32"``). + detector_name : str + Object detector used for animal bounding boxes. + max_individuals : int + Maximum number of individuals to detect per image (default 1). + """ + pose2d_model: str = "superanimal" + superanimal_name: str = "superanimal_quadruped" + sa_model_name: str = "hrnet_w32" + detector_name: str = "fasterrcnn_resnet50_fpn_v2" + max_individuals: int = 1 + + @dataclass class DemoConfig: """Demo / inference configuration.""" @@ -239,8 +304,6 @@ class PipelineConfig: demo_cfg: DemoConfig = field(default_factory=DemoConfig) runtime_cfg: RuntimeConfig = field(default_factory=RuntimeConfig) - # -- construction from argparse namespace --------------------------------- - @classmethod def from_namespace(cls, ns) -> "PipelineConfig": """Build a :class:`PipelineConfig` from an ``argparse.Namespace`` @@ -258,10 +321,14 @@ def _pick(dc_class, src: dict): kwargs = {} for group_name, dc_class in _SUB_CONFIG_CLASSES.items(): - if group_name == "model_cfg" and raw.get("model_type", "fmpose3d") == "fmpose3d": + if group_name == "model_cfg" and raw.get("model_type", 'fmpose3d') in _FMPOSE3D_DEFAULTS: dc_class = FMPose3DConfig - elif group_name == "pose2d_cfg" and raw.get("pose2d_model", "hrnet") == "hrnet": - dc_class = HRNetConfig + elif group_name == "pose2d_cfg": + p2d = raw.get("pose2d_model", "hrnet") + if p2d == "superanimal": + dc_class = SuperAnimalConfig + elif p2d == "hrnet": + dc_class = HRNetConfig kwargs[group_name] = _pick(dc_class, raw) return cls(**kwargs) diff --git a/fmpose3d/lib/hrnet/hrnet.py b/fmpose3d/lib/hrnet/hrnet.py index 1368b30..0d0b752 
100644 --- a/fmpose3d/lib/hrnet/hrnet.py +++ b/fmpose3d/lib/hrnet/hrnet.py @@ -5,9 +5,7 @@ "FMPose3D: monocular 3D Pose Estimation via Flow Matching" by Ti Wang, Xiaohang Yu, and Mackenzie Weygandt Mathis Licensed under Apache 2.0 -""" -""" FMPose3D – clean HRNet 2D pose estimation API. Provides :class:`HRNetPose2d`, a self-contained wrapper around the diff --git a/fmpose3d/models/__init__.py b/fmpose3d/models/__init__.py index b9dc64a..dd13e64 100644 --- a/fmpose3d/models/__init__.py +++ b/fmpose3d/models/__init__.py @@ -15,6 +15,8 @@ # Import model subpackages so their @register_model decorators execute. from .fmpose3d import Graph, Model +# Import animal models so their @register_model decorators execute. +from fmpose3d.animals import models as _animal_models # noqa: F401 __all__ = [ "BaseModel", diff --git a/scripts/FMPose3D_main.py b/scripts/FMPose3D_main.py index e9adf3a..beb88da 100644 --- a/scripts/FMPose3D_main.py +++ b/scripts/FMPose3D_main.py @@ -41,8 +41,9 @@ spec.loader.exec_module(module) CFM = getattr(module, "Model") else: - # Load model from installed fmpose package - from fmpose3d.models import Model as CFM + # Load model from registered model registry + from fmpose3d.models import get_model + CFM = get_model(args.model_type) def test_multi_hypothesis( @@ -281,12 +282,6 @@ def print_error_action(action_error_sum, is_train): src=script_path, dst=os.path.join(args.checkpoint, args.create_time + "_" + script_name), ) - if getattr(args, "model_path", ""): - model_src_path = os.path.abspath(args.model_path) - model_dst_name = f"{args.create_time}_" + os.path.basename(model_src_path) - shutil.copyfile( - src=model_src_path, dst=os.path.join(args.checkpoint, model_dst_name) - ) sh_base = os.path.basename(args.sh_file) dst_name = f"{args.create_time}_" + sh_base shutil.copyfile(src=args.sh_file, dst=os.path.join(args.checkpoint, dst_name)) diff --git a/scripts/FMPose3D_test.sh b/scripts/FMPose3D_test.sh index 3d2a615..b83d65d 100755 --- 
a/scripts/FMPose3D_test.sh +++ b/scripts/FMPose3D_test.sh @@ -10,7 +10,7 @@ mode='exp' exp_temp=0.005 folder_name=test_s${eval_multi_steps}_${mode}_h${num_hypothesis_list}_$(date +%Y%m%d_%H%M%S) -model_path='./pre_trained_models/fmpose3d_h36m/model_GAMLP.py' +model_type='fmpose3d' model_weights_path='./pre_trained_models/fmpose3d_h36m/FMpose3D_pretrained_weights.pth' #Test @@ -20,7 +20,7 @@ python3 scripts/FMPose3D_main.py \ --exp_temp ${exp_temp} \ --folder_name ${folder_name} \ --model_weights_path "${model_weights_path}" \ ---model_path "${model_path}" \ +--model_type "${model_type}" \ --eval_sample_steps ${eval_multi_steps} \ --test_augmentation True \ --batch_size ${batch_size} \ diff --git a/scripts/FMPose3D_train.sh b/scripts/FMPose3D_train.sh index 939658c..a677d9c 100755 --- a/scripts/FMPose3D_train.sh +++ b/scripts/FMPose3D_train.sh @@ -11,8 +11,7 @@ epochs=80 num_saved_models=3 frames=1 channel_dim=512 -model_path="" # when the path is empty, the model will be loaded from the installed fmpose3d package -# model_path='./models/model_GAMLP.py' # when the path is not empty, the model will be loaded from the local file path +model_type='fmpose3d' # use registered model by default sh_file='scripts/FMPose3D_train.sh' folder_name=FMPose3D_layers${layers}_$(date +%Y%m%d_%H%M%S) @@ -20,6 +19,7 @@ python3 scripts/FMPose3D_main.py \ --train \ --dataset h36m \ --frames ${frames} \ + --model_type "${model_type}" \ ${model_path:+--model_path "$model_path"} \ --gpu ${gpu_id} \ --batch_size ${batch_size} \