Description
Is there an existing issue for this?
- I have searched the existing issues
Bug description
Hi DLC team,
I think there might be a bug with these lines (version DLC rc4).
/home/hice1/tnguyen868/scratch/DLC_dnv/lib/python3.10/site-packages/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py (line 55)
I ran into this function while implementing a transformer re-ID training pipeline for the DLC PyTorch engine, since so far it has only been implemented for the DLC TF engine. I mostly reused functions already available in the DeepLabCut PyTorch engine, such as create_triplets_dataset(), create_train_using_pickle(), generate_train_triplets_from_pickle(), save_train_triplets(), etc.
After fixing this bug, my transformer re-ID training accuracies jumped from 50-60% to 95-100% (see below), so I think this bug might be interfering with the transformer re-ID pipeline. Here are the details for your reference and replication.
def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]
    diff = coordinates - ref_coord
    diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
    match_id = np.argmin(np.nanmean(diff, axis=(1, 2)))
    return features[match_id]
Here are the values at which the bug shows up, just before the line match_id = … is executed. These are the values I got in the debug console for the variables in this function. I have also included the shapes of some arrays for convenience.
As you can see, the code is trying to find which animal in coordinates most closely matches the points in ref_coord. The answer stored in match_id should be 1, not 2. My explanation follows the debug output.
ref_coord
array([[ 583, 19],
[ 588, 5],
[ 576, 8],
[ 581, 5],
[-9223372036854775808, -9223372036854775808],
[-9223372036854775808, -9223372036854775808],
[-9223372036854775808, -9223372036854775808],
[-9223372036854775808, -9223372036854775808],
[-9223372036854775808, -9223372036854775808]])
coordinates.shape
(10, 9, 2)
coordinates
array([[[ 1.167093e+03, 4.204960e+02],
[ 1.160298e+03, 4.386300e+02],
[ 1.184778e+03, 4.359550e+02],
[ 1.173923e+03, 4.392800e+02],
[ 1.180263e+03, 4.664130e+02],
[ 1.196231e+03, 5.207390e+02],
[ 1.202038e+03, 6.048250e+02],
[ 1.167451e+03, 4.706990e+02],
[ 1.196206e+03, 4.635440e+02]],
[[ 5.834780e+02, 1.915800e+01],
[ 5.886130e+02, 5.967000e+00],
[ 5.766360e+02, 8.133000e+00],
[ 5.819810e+02, 5.942000e+00],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]],
[[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00],
[-1.000000e+00, -1.000000e+00]]], dtype=float32)
diff.shape
(10, 9, 2)
diff
array([[[5.84093018e+02, 4.01496002e+02],
[5.72297974e+02, 4.33630005e+02],
[6.08777954e+02, 4.27954987e+02],
[5.92922974e+02, 4.34279999e+02],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[4.78027344e-01, 1.58000946e-01],
[6.12976074e-01, 9.67000008e-01],
[6.35986328e-01, 1.33000374e-01],
[9.81018066e-01, 9.41999912e-01],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]],
[[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan],
[ nan, nan]]])
np.nanmean(diff, axis=(1, 2)) # intermediate value of the line: match_id = ....
<string>:1: RuntimeWarning: Mean of empty slice
array([506.93161392, 0.61350113, nan, nan,
nan, nan, nan, nan,
nan, nan])
np.argmin(np.nanmean(diff, axis=(1, 2)))
<string>:1: RuntimeWarning: Mean of empty slice
2
The bug
Many elements in diff are replaced with NaN. If an entire slice diff[i] is NaN across axes (1, 2), the mean of that slice is undefined.
np.nanmean excludes NaN values, but when a slice contains only NaN values there is nothing left to average, resulting in:
- A RuntimeWarning: Mean of empty slice.
- np.nanmean returning NaN for that slice.
np.argmin is then applied to the result of np.nanmean(diff, axis=(1, 2)), which includes NaN values.
np.argmin does not skip NaN values:
- If the array contains NaN, it treats the first NaN as the minimum and returns its index. In the output above that is index 2, even though index 1 has the smallest finite mean.
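The behaviour described above can be reproduced in isolation. This is a minimal sketch that uses the per-animal means from the debug output (rounded) instead of real DLC data; the variable name means is only for illustration, and np.nanargmin is shown purely as a point of comparison, not as part of the existing code.

```python
import numpy as np

# Per-animal mean differences, rounded from the debug output above:
# animal 0 is far away, animal 1 is the true match, and the remaining
# eight entries come from all-NaN slices.
means = np.array([506.93, 0.61] + [np.nan] * 8)

# np.argmin returns the index of the first NaN, not of the smallest finite value.
print(np.argmin(means))     # 2

# np.nanargmin ignores NaN entries and finds the smallest finite value.
print(np.nanargmin(means))  # 1
```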
Suggested fix
Here is my fix, which makes sure that the feature at the correct index is extracted from the list of animals (10 animals in my case).
def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]
    diff = coordinates - ref_coord
    diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
    # match_id = np.argmin(np.nanmean(diff, axis=(1, 2)))  # existing code
    # THUAN's fix:
    masked_means = np.ma.masked_invalid(np.nanmean(diff, axis=(1, 2)))
    match_id = np.argmin(masked_means)
    return features[match_id]
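To check the fix without a full DLC project, it can be run on a small synthetic feature_dict. The data below are made up for illustration (three animals, two keypoints, 4-dimensional features), and the warning suppression is an addition for this demo only, not part of the proposed patch. np.nanargmin might be an alternative to the masked array, with the difference that it raises a ValueError when every animal's slice is NaN.

```python
import warnings

import numpy as np


def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]
    diff = coordinates - ref_coord
    diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
    with warnings.catch_warnings():
        # Demo-only addition: all-NaN slices are expected here,
        # so silence the "Mean of empty slice" RuntimeWarning.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        means = np.nanmean(diff, axis=(1, 2))
    # Mask the NaN means so that argmin only considers finite values.
    masked_means = np.ma.masked_invalid(means)
    match_id = np.argmin(masked_means)
    return features[match_id]


# Hypothetical data, not taken from a real project.
coordinates = np.array(
    [
        [[1167.1, 420.5], [1160.3, 438.6]],  # far from ref_coord
        [[583.5, 19.2], [588.6, 6.0]],       # the true match
        [[-1.0, -1.0], [-1.0, -1.0]],        # placeholder for a missing animal
    ],
    dtype=np.float32,
)
features = np.arange(12, dtype=np.float32).reshape(3, 4)
feature_dict = {0: {"features": features, "coordinates": coordinates}}
ref_coord = np.array([[583, 19], [588, 5]])

matched = query_feature_by_coord_in_img_space(feature_dict, 0, ref_coord)
print(matched)  # features of animal 1
```

With the original np.argmin(np.nanmean(...)) line, the same input would select animal 2, the all-NaN placeholder.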
Results of transformer re-ID training – BEFORE
n_triplets of 1000
Training transformer re-identification model...
Epoch 10, train acc: 0.64
Epoch 10, test acc 0.55
Epoch 20, train acc: 0.64
Epoch 20, test acc 0.55
Epoch 30, train acc: 0.64
Epoch 30, test acc 0.55
Epoch 40, train acc: 0.65
Epoch 40, test acc 0.55
Epoch 50, train acc: 0.64
Epoch 50, test acc 0.54
Epoch 60, train acc: 0.64
Epoch 60, test acc 0.54
Epoch 70, train acc: 0.65
Epoch 70, test acc 0.54
Epoch 80, train acc: 0.65
Epoch 80, test acc 0.54
Epoch 90, train acc: 0.65
Epoch 90, test acc 0.54
Epoch 100, train acc: 0.65
Epoch 100, test acc 0.54
n_triplets of 10000
Training transformer re-identification model...
Epoch 10, train acc: 0.64
Epoch 10, test acc 0.55
Epoch 20, train acc: 0.62
Epoch 20, test acc 0.55
Epoch 30, train acc: 0.63
Epoch 30, test acc 0.55
Epoch 40, train acc: 0.64
Epoch 40, test acc 0.55
Epoch 50, train acc: 0.64
Epoch 50, test acc 0.55
Epoch 60, train acc: 0.64
Epoch 60, test acc 0.55
Epoch 70, train acc: 0.65
Epoch 70, test acc 0.54
Epoch 80, train acc: 0.66
Epoch 80, test acc 0.54
Epoch 90, train acc: 0.64
Epoch 90, test acc 0.54
Epoch 100, train acc: 0.64
Epoch 100, test acc 0.54
Results of transformer re-ID training – AFTER
Training transformer re-identification model...
Epoch 10, train acc: 0.95
Epoch 10, test acc 0.95
Epoch 20, train acc: 0.97
Epoch 20, test acc 0.96
Epoch 30, train acc: 0.97
Epoch 30, test acc 0.96
Epoch 40, train acc: 0.97
Epoch 40, test acc 0.96
Epoch 50, train acc: 0.97
Epoch 50, test acc 0.96
Epoch 60, train acc: 0.97
Epoch 60, test acc 0.96
Epoch 70, train acc: 0.98
Epoch 70, test acc 0.96
Epoch 80, train acc: 0.98
Epoch 80, test acc 0.96
Epoch 90, train acc: 0.97
Epoch 90, test acc 0.96
Epoch 100, train acc: 0.97
Epoch 100, test acc 0.96
Training transformer re-identification model...
Epoch 10, train acc: 0.96
Epoch 10, test acc 0.95
Epoch 20, train acc: 0.99
Epoch 20, test acc 0.98
Epoch 30, train acc: 1.00
Epoch 30, test acc 1.00
Epoch 40, train acc: 1.00
Epoch 40, test acc 1.00
Epoch 50, train acc: 1.00
Epoch 50, test acc 0.99
Epoch 60, train acc: 1.00
Epoch 60, test acc 1.00
Epoch 70, train acc: 1.00
Epoch 70, test acc 1.00
Epoch 80, train acc: 1.00
Epoch 80, test acc 1.00
Epoch 90, train acc: 1.00
Epoch 90, test acc 1.00
Epoch 100, train acc: 1.00
Epoch 100, test acc 1.00
Operating System
operating system rhel
DeepLabCut version
dlc version rc4
DeepLabCut mode
multi animal
Device type
gpu
Steps To Reproduce
please see above
Relevant log output
please see above
Anything else?
please see above
Code of Conduct
- I agree to follow this project's Code of Conduct