Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
ff54deb
Remove unused imports
maximpavliv Jul 8, 2025
0963bee
_build_tf_attributes - add pseudo threshold widgets back
maximpavliv Jul 8, 2025
2bb52de
Restore hiding torch_widget when TF Engine is selected
maximpavliv Jul 8, 2025
f0d4f5a
Hide adaptation settings when adaptation is unchecked
maximpavliv Jul 8, 2025
c5306b9
Hide/Show detector row on engine change
maximpavliv Jul 8, 2025
930c71f
Hide detector row when humanbody supermodel is selected
maximpavliv Jul 8, 2025
2438c13
Remove unused method
maximpavliv Jul 8, 2025
0765c8e
Hide torch adaptation widgets when humanbody is selected
maximpavliv Jul 8, 2025
f93530c
video_inference_superanimal: add create_labeled_video argument
maximpavliv Jun 30, 2025
64d663b
Rename method
maximpavliv Jul 8, 2025
6d7e7ef
Add create labeled video checkbox
maximpavliv Jul 8, 2025
207f644
Use set_combo_items() to update combo boxes
maximpavliv Jul 8, 2025
7bf43a0
Add batch size combo boxes for pose model and detector
maximpavliv Jul 8, 2025
3fd8ead
Replace MediaSelectorWidget with VideoSelectorWidget
maximpavliv Jul 9, 2025
a19e980
VideoSelectorWidget: add hide_videotype option
maximpavliv Jul 9, 2025
d50e5b4
Refactor: split _build_common_attributes() into smaller functions
maximpavliv Jul 10, 2025
88d4743
Refactor: split _build_tf_attributes and _build_torch_attributes into…
maximpavliv Jul 10, 2025
1435920
Refactoring: single _add_use_adaptation_row method for tf and torch
maximpavliv Jul 10, 2025
ba3012a
Merge branch 'maxim/restore_video_selection_widget' into maxim/modelz…
maximpavliv Jul 10, 2025
2abf37c
Trim superanimal_humanbody.yaml default project config
maximpavliv Jul 21, 2025
a1a6be1
Trim superanimal_humanbody_colors
maximpavliv Jul 21, 2025
dfbce1d
Correct get_checkpoint_epoch
maximpavliv Jul 21, 2025
1432a73
Add rtmpose_x modelzoo model config
maximpavliv Jul 21, 2025
a4d74cc
Add FilteredDetector
maximpavliv Jul 21, 2025
0cbfe59
Add get_filtered_coco_detector_inference_runner() method
maximpavliv Jul 21, 2025
84b230e
Add ScaleToUnitRange transform
maximpavliv Jul 21, 2025
c4c1318
Superanimal humanbody inference: use filtered detector runner
maximpavliv Jul 21, 2025
dc511cd
ModelZoo tab: make humanbody general case
maximpavliv Jul 21, 2025
479cc66
get_super_animal_scorer(): add torchvision_detector_name arg
maximpavliv Jul 21, 2025
6a14584
Remove superanimal_humanbody_video_inference.py module
maximpavliv Jul 21, 2025
6774812
Regularize get_super_animal_model_config_path()
maximpavliv Jul 21, 2025
424ac92
Regularize load_super_animal_config()
maximpavliv Jul 21, 2025
04438f8
Regularize download_super_animal_snapshot()
maximpavliv Jul 21, 2025
256977b
update_config(): superanimal_humanbody - compatible
maximpavliv Jul 21, 2025
8d4cf20
Revert video_inference()
maximpavliv Jul 21, 2025
edf4f9d
Revert create_df_from_prediction()
maximpavliv Jul 21, 2025
cded581
Restore CTDInferenceRunner
maximpavliv Jul 21, 2025
7eaf923
Remove TorchvisionDetectorInferenceRunner
maximpavliv Jul 21, 2025
feb34d1
Revert DetectorInferenceRunner
maximpavliv Jul 21, 2025
711a47e
superanimal_analyze_images() - make humanbody compatible
maximpavliv Jul 21, 2025
28e4dd0
Revert build_predictions_dataframe()
maximpavliv Jul 21, 2025
193f935
Revert get_inference_runners()
maximpavliv Jul 21, 2025
613b2ac
Revert detectors/fasterRCNN.py
maximpavliv Jul 21, 2025
d848741
Revert detectors/torchvision.py
maximpavliv Jul 21, 2025
bba8689
Revert base Runner
maximpavliv Jul 21, 2025
fa45d26
Merge branch 'maxim/superanimal_humanbody_filtered_detector' into max…
maximpavliv Jul 25, 2025
d8f5bdb
No video adaptation if superanimal_humanbody
maximpavliv Jul 25, 2025
1a03360
Remove deprecated skipped frames msg
maximpavliv Jul 26, 2025
b1d982e
Fix Run button style change
maximpavliv Jul 26, 2025
f77ccda
Universalize created videos msg
maximpavliv Jul 26, 2025
c6ecb0f
Superanimal_humanbody - add detector widget
maximpavliv Jul 26, 2025
25fa08d
Fix superanimal_humanbody unit test
maximpavliv Jul 31, 2025
425484b
Disable video adaptation for superanimal_humanbody
maximpavliv Jul 31, 2025
4b04013
Fix testscript_superanimal_inference.py
maximpavliv Jul 31, 2025
eea470e
Remove debug print
maximpavliv Jul 31, 2025
3d47a36
Black formatting
maximpavliv Jul 31, 2025
fb90004
Merge branch 'maxim/superanimal_humanbody_filtered_detector' into max…
maximpavliv Jul 31, 2025
749de6d
Black formatting
maximpavliv Jul 31, 2025
63d7208
Merge branch 'main' into maxim/modelzoo_tab_fixes
maximpavliv Sep 5, 2025
19a2696
Add superanimal_humanbody video adaptation
maximpavliv Sep 3, 2025
ce0d201
Add superanimal_humanbody video adaptation to GUI
maximpavliv Sep 3, 2025
c39db66
Add warning
maximpavliv Sep 4, 2025
bccf4f3
Raise exception in build_weight_init()
maximpavliv Sep 4, 2025
e9ff1a1
Black formatting
maximpavliv Sep 4, 2025
efca1c7
Read adapt_checkbox for superanimal_humanbody
maximpavliv Sep 10, 2025
c5eeb2e
Merge branch 'main' into maxim/superanimal_humanbody_video_adaptation
maximpavliv Sep 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions deeplabcut/gui/tabs/modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def _build_common_attributes(self) -> None:

self.model_combo.currentTextChanged.connect(self._update_pose_models)
self.model_combo.currentTextChanged.connect(self._update_detectors)
self.model_combo.currentTextChanged.connect(self._update_adaptation_visibility)
self.model_combo.currentTextChanged.connect(
self._update_adaptation_detector_visibility
)

def _add_tf_scales_row(self, layout: QtWidgets.QGridLayout):
scales_label = QtWidgets.QLabel("Scale list")
Expand Down Expand Up @@ -359,8 +361,10 @@ def _add_torch_adaptation_settings_row(self, layout: QtWidgets.QGridLayout):
self.torch_adapt_epoch_spinbox.setRange(1, 50)
self.torch_adapt_epoch_spinbox.setValue(4)
self.torch_adapt_epoch_spinbox.setMaximumWidth(100)
adapt_det_epoch_label = QtWidgets.QLabel("Number of detector adaptation epochs")
adapt_det_epoch_label.setMinimumWidth(200)
self.adapt_det_epoch_label = QtWidgets.QLabel(
"Number of detector adaptation epochs"
)
self.adapt_det_epoch_label.setMinimumWidth(200)
self.torch_adapt_det_epoch_spinbox = QtWidgets.QSpinBox()
self.torch_adapt_det_epoch_spinbox.setRange(1, 50)
self.torch_adapt_det_epoch_spinbox.setValue(4)
Expand All @@ -374,7 +378,7 @@ def _add_torch_adaptation_settings_row(self, layout: QtWidgets.QGridLayout):
self.torch_adaptation_settings_row.addWidget(adapt_epoch_label)
self.torch_adaptation_settings_row.addWidget(self.torch_adapt_epoch_spinbox)
self.torch_adaptation_settings_row.addSpacing(20)
self.torch_adaptation_settings_row.addWidget(adapt_det_epoch_label)
self.torch_adaptation_settings_row.addWidget(self.adapt_det_epoch_label)
self.torch_adaptation_settings_row.addWidget(self.torch_adapt_det_epoch_spinbox)
self.torch_adaptation_settings_row.addStretch()
layout.addLayout(self.torch_adaptation_settings_row, 2, 0, 1, 6)
Expand All @@ -401,6 +405,10 @@ def _adapt_checkbox_status_changed(self, state: int) -> None:
set_layout_contents_visible(
self.torch_adaptation_settings_row, Qt.CheckState(state) == Qt.Checked
)
if Qt.CheckState(state) == Qt.Checked:
self._update_adaptation_detector_visibility(
self.model_combo.currentText()
)

def select_folder(self):
dirname = QtWidgets.QFileDialog.getExistingDirectory(
Expand Down Expand Up @@ -563,11 +571,7 @@ def _gather_kwargs(self) -> dict:
kwargs["adapt_iterations"] = self.adapt_iter_spinbox.value()
else:
kwargs["detector_name"] = self.detector_type_selector.currentText()
kwargs["video_adapt"] = (
self.adapt_checkbox.isChecked()
if self.model_combo.currentText() != "superanimal_humanbody"
else False
)
kwargs["video_adapt"] = (self.adapt_checkbox.isChecked())
kwargs["pseudo_threshold"] = self.pose_threshold_spinbox.value()
kwargs["bbox_threshold"] = self.detector_threshold_spinbox.value()
kwargs["detector_epochs"] = self.torch_adapt_det_epoch_spinbox.value()
Expand Down Expand Up @@ -624,14 +628,11 @@ def _update_detectors(self, super_animal: str) -> None:
self.detector_row, self.root.engine == Engine.PYTORCH
)

def _update_adaptation_visibility(self, super_animal: str):
if (
self.root.engine == Engine.PYTORCH
and super_animal != "superanimal_humanbody"
):
self.torch_widget.show()
else:
self.torch_widget.hide()
def _update_adaptation_detector_visibility(self, superanimal: str):
self.adapt_det_epoch_label.setVisible((superanimal != "superanimal_humanbody"))
self.torch_adapt_det_epoch_spinbox.setVisible(
(superanimal != "superanimal_humanbody")
)

@Slot(Engine)
def _on_engine_change(self, engine: Engine) -> None:
Expand Down
102 changes: 58 additions & 44 deletions deeplabcut/modelzoo/video_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,12 +402,6 @@ def video_inference_superanimal(

output_suffix = "_before_adapt"

if superanimal_name == "superanimal_humanbody" and video_adapt:
print(
f"Video adaptation currently not supported for {superanimal_name}. Setting it to false."
)
video_adapt = False

if video_adapt:
# the users can pass in many videos. For now, we only use one video for
# video adaptation. As reported in Ye et al. 2024, one video should be
Expand Down Expand Up @@ -485,60 +479,72 @@ def video_inference_superanimal(
)

model_snapshot_prefix = f"snapshot-{model_name}"
detector_snapshot_prefix = f"snapshot-{detector_name}"

config["runner"]["snapshot_prefix"] = model_snapshot_prefix
config["detector"]["runner"]["snapshot_prefix"] = detector_snapshot_prefix

if superanimal_name != "superanimal_humanbody":
detector_snapshot_prefix = f"snapshot-{detector_name}"
config["detector"]["runner"][
"snapshot_prefix"
] = detector_snapshot_prefix

# the model config's parameters need to be updated for adaptation training
model_config_path = model_folder / "pytorch_config.yaml"
with open(model_config_path, "w") as f:
yaml = YAML()
yaml.dump(config, f)

# get the current epoch of the detector and pose model
# get the current epoch of the pose model
current_pose_epoch = get_checkpoint_epoch(pose_model_path)
current_detector_epoch = get_checkpoint_epoch(detector_path)
# update the checkpoint path with the current epoch, if the checkpoint does not exist, use the best checkpoint
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-{current_detector_epoch + detector_epochs:03}.pt"
)
adapted_pose_checkpoint = (
model_folder
/ f"{model_snapshot_prefix}-{current_pose_epoch + pose_epochs:03}.pt"
)
if not Path(adapted_detector_checkpoint).exists():
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-best-{current_detector_epoch + detector_epochs:03}.pt"
)
if not Path(adapted_pose_checkpoint).exists():
adapted_pose_checkpoint = (
model_folder
/ f"{model_snapshot_prefix}-best-{current_pose_epoch + pose_epochs:03}.pt"
)

if superanimal_name != "superanimal_humanbody":
current_detector_epoch = get_checkpoint_epoch(detector_path)
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-{current_detector_epoch + detector_epochs:03}.pt"
)
if not Path(adapted_detector_checkpoint).exists():
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-best-{current_detector_epoch + detector_epochs:03}.pt"
)

if (
adapted_detector_checkpoint.exists()
and adapted_pose_checkpoint.exists()
):
superanimal_name == "superanimal_humanbody"
or adapted_detector_checkpoint.exists()
) and adapted_pose_checkpoint.exists():
snapshots_msg = f"pose ({adapted_pose_checkpoint})"
if superanimal_name != "superanimal_humanbody":
snapshots_msg += f" and detector ({adapted_detector_checkpoint})"
print(
f"Video adaptation already ran; pose ({adapted_pose_checkpoint}) "
f"and detector ({adapted_detector_checkpoint}) already exist. To "
"rerun video adaptation training, delete the checkpoints or select"
"a different number of adaptation epochs. Continuing with the"
"existing checkpoints."
f"Video adaptation already ran; {snapshots_msg} already exist. "
"To rerun video adaptation training, delete the checkpoints or select a different "
"number of adaptation epochs. Continuing with the existing checkpoints."
)
else:
print(
"Running video adaptation with following parameters:\n"
params_msg = (
f" video adaptation batch size: {video_adapt_batch_size}\n"
f" (pose training) pose_epochs: {pose_epochs}\n"
" (pose) save_epochs: 1\n"
f" detector_epochs: {detector_epochs}\n"
" detector_save_epochs: 1\n"
f" video adaptation batch size: {video_adapt_batch_size}\n"
)
if superanimal_name != "superanimal_humanbody":
params_msg += (
f" detector_epochs: {detector_epochs}\n"
" detector_save_epochs: 1\n"
)
print(
"Running video adaptation with following parameters:\n" + params_msg
)

train_file = pseudo_dataset_folder / "annotations" / "train.json"
with open(train_file, "r") as f:
temp_obj = json.load(f)
Expand All @@ -551,6 +557,11 @@ def video_inference_superanimal(
)
return

if superanimal_name == "superanimal_humanbody":
print(
"Warning, with the superanimal_humanbody type, only the pose model is adapted"
)

adaptation_train(
project_root=pseudo_dataset_folder,
model_folder=model_folder,
Expand All @@ -566,32 +577,35 @@ def video_inference_superanimal(
detector_path=detector_path,
batch_size=video_adapt_batch_size,
detector_batch_size=video_adapt_batch_size,
skip_detector=(superanimal_name == "superanimal_humanbody"),
)

# after video adaptation, re-update the adapted checkpoint path, if the checkpoint does not exist, use the best checkpoint
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-{current_detector_epoch + detector_epochs:03}.pt"
)
adapted_pose_checkpoint = (
model_folder
/ f"{model_snapshot_prefix}-{current_pose_epoch + pose_epochs:03}.pt"
)
if not Path(adapted_detector_checkpoint).exists():
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-best-{current_detector_epoch + detector_epochs:03}.pt"
)
if not Path(adapted_pose_checkpoint).exists():
adapted_pose_checkpoint = (
model_folder
/ f"{model_snapshot_prefix}-best-{current_pose_epoch + pose_epochs:03}.pt"
)
pose_model_path = adapted_pose_checkpoint

if superanimal_name != "superanimal_humanbody":
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-{current_detector_epoch + detector_epochs:03}.pt"
)
if not Path(adapted_detector_checkpoint).exists():
adapted_detector_checkpoint = (
model_folder
/ f"{detector_snapshot_prefix}-best-{current_detector_epoch + detector_epochs:03}.pt"
)
detector_path = adapted_detector_checkpoint

# Set the customized checkpoint paths and
output_suffix = "_after_adapt"
detector_path = adapted_detector_checkpoint
pose_model_path = adapted_pose_checkpoint

return _video_inference_superanimal(
videos,
Expand Down
7 changes: 6 additions & 1 deletion deeplabcut/modelzoo/weight_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from deeplabcut.core.config import read_config_as_dict
from deeplabcut.core.weight_init import WeightInitialization
from deeplabcut.pose_estimation_pytorch.modelzoo.utils import (
get_super_animal_snapshot_path
get_super_animal_snapshot_path,
)


Expand Down Expand Up @@ -71,6 +71,11 @@ def build_weight_init(
Returns:
The built WeightInitialization.
"""
if super_animal == "superanimal_humanbody":
raise NotImplementedError(
"Weight Initialization, Transfer-Learning and Finetuning is currently not supported for superanimal_humanbody"
)

if isinstance(cfg, (str, Path)):
cfg = read_config_as_dict(cfg)

Expand Down
18 changes: 12 additions & 6 deletions deeplabcut/pose_estimation_pytorch/modelzoo/train_from_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def adaptation_train(
batch_size: int = 8,
detector_batch_size: int = 8,
eval_interval: int | None = None,
skip_detector: bool = False,
):
setup_file_logging(Path(model_folder) / "log.txt")
loader = COCOLoader(
Expand All @@ -49,25 +50,30 @@ def adaptation_train(
utils.fix_seeds(loader.model_cfg["train_settings"]["seed"])

updates = {
"detector.model.freeze_bn_stats": True,
"detector.runner.snapshots.max_snapshots": 5,
"detector.runner.snapshots.save_epochs": detector_save_epochs or 1,
"detector.train_settings.batch_size": detector_batch_size,
"detector.train_settings.epochs": detector_epochs or 4,
"model.backbone.freeze_bn_stats": True,
"runner.snapshots.max_snapshots": 5,
"runner.snapshots.save_epochs": save_epochs or 1,
"train_settings.batch_size": batch_size,
"train_settings.epochs": epochs or 4,
}
if not skip_detector:
updates.update(
{
"detector.model.freeze_bn_stats": True,
"detector.runner.snapshots.max_snapshots": 5,
"detector.runner.snapshots.save_epochs": detector_save_epochs or 1,
"detector.train_settings.batch_size": detector_batch_size,
"detector.train_settings.epochs": detector_epochs or 4,
}
)

if eval_interval is not None:
updates["runner.eval_interval"] = eval_interval

loader.update_model_cfg(updates)

pose_task = Task(loader.model_cfg["method"])
if pose_task == Task.TOP_DOWN:
if pose_task == Task.TOP_DOWN and not skip_detector:
logger_config = None
if loader.model_cfg.get("logger"):
logger_config = copy.deepcopy(loader.model_cfg["logger"])
Expand Down
Loading