
np.nanmean bug in creating triplets for transformer re-ID: query_feature_by_coord_in_img_space() function #2794

@thuann2cats

Description


Is there an existing issue for this?

  • I have searched the existing issues

Bug description

Hi DLC team,

I think there might be a bug in the following lines (DLC version rc4):
/home/hice1/tnguyen868/scratch/DLC_dnv/lib/python3.10/site-packages/deeplabcut/pose_tracking_pytorch/tracking_utils/preprocessing.py (line 55)

I ran into this function while implementing a transformer re-ID training pipeline for the DLC PyTorch engine, since that pipeline has so far only been implemented for the DLC TF engine. I mostly reused functions already available in the DeepLabCut PyTorch engine, such as create_triplets_dataset(), create_train_using_pickle(), generate_train_triplets_from_pickle(), save_train_triplets(), etc.

After fixing this bug, my transformer re-ID training accuracies jumped from roughly 55-65% to 95-100% (see the logs below), so I think this bug may be interfering with the transformer re-ID pipeline. The details are below so you can replicate the problem.

import numpy as np

def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]

    diff = coordinates - ref_coord
    diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
    match_id = np.argmin(np.nanmean(diff, axis=(1, 2)))

    return features[match_id]

Here are the values where the bug seems to manifest itself, just before the line match_id = … runs. These are the values of the variables in this function as shown in the debug console; I have also included the shapes of some arrays for convenience.

As you can see, the code tries to determine which animal in coordinates most closely matches the keypoints in ref_coord. The correct answer, which should end up in match_id, is 1 (every visible keypoint of animal 1 lies within one pixel of ref_coord), not 2. My explanation follows below.

ref_coord
array([[                 583,                   19],
       [                 588,                    5],
       [                 576,                    8],
       [                 581,                    5],
       [-9223372036854775808, -9223372036854775808],
       [-9223372036854775808, -9223372036854775808],
       [-9223372036854775808, -9223372036854775808],
       [-9223372036854775808, -9223372036854775808],
       [-9223372036854775808, -9223372036854775808]])
coordinates.shape
(10, 9, 2)
coordinates
array([[[ 1.167093e+03,  4.204960e+02],
        [ 1.160298e+03,  4.386300e+02],
        [ 1.184778e+03,  4.359550e+02],
        [ 1.173923e+03,  4.392800e+02],
        [ 1.180263e+03,  4.664130e+02],
        [ 1.196231e+03,  5.207390e+02],
        [ 1.202038e+03,  6.048250e+02],
        [ 1.167451e+03,  4.706990e+02],
        [ 1.196206e+03,  4.635440e+02]],

       [[ 5.834780e+02,  1.915800e+01],
        [ 5.886130e+02,  5.967000e+00],
        [ 5.766360e+02,  8.133000e+00],
        [ 5.819810e+02,  5.942000e+00],
        [          nan,           nan],
        [          nan,           nan],
        [          nan,           nan],
        [          nan,           nan],
        [          nan,           nan]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]],

       [[-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00],
        [-1.000000e+00, -1.000000e+00]]], dtype=float32)
diff.shape
(10, 9, 2)
diff
array([[[5.84093018e+02, 4.01496002e+02],
        [5.72297974e+02, 4.33630005e+02],
        [6.08777954e+02, 4.27954987e+02],
        [5.92922974e+02, 4.34279999e+02],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[4.78027344e-01, 1.58000946e-01],
        [6.12976074e-01, 9.67000008e-01],
        [6.35986328e-01, 1.33000374e-01],
        [9.81018066e-01, 9.41999912e-01],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]],

       [[           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan],
        [           nan,            nan]]])
np.nanmean(diff, axis=(1, 2))  # intermediate value of the line: match_id = ....
<string>:1: RuntimeWarning: Mean of empty slice
array([506.93161392,   0.61350113,          nan,          nan,
                nan,          nan,          nan,          nan,
                nan,          nan])
np.argmin(np.nanmean(diff, axis=(1, 2)))
<string>:1: RuntimeWarning: Mean of empty slice
2
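The failure can be replicated without a trained model. The sketch below is a reduced copy of the dump above (values rounded): it keeps only keypoints 0, 1 and 4 of animals 0, 1 and 2, so it covers a valid keypoint, a missing keypoint (the -9223372036854775808 entries in ref_coord, which equal np.iinfo(np.int64).min) and an empty animal slot filled with -1. The masking line is the same as in the original function; everything else is illustrative.

```python
import warnings

import numpy as np

# Reduced copy of the reported frame: 3 animals x 3 keypoints x (x, y).
coordinates = np.array(
    [
        [[1167.093, 420.496], [1160.298, 438.630], [1180.263, 466.413]],  # animal 0: far away
        [[583.478, 19.158], [588.613, 5.967], [np.nan, np.nan]],          # animal 1: true match
        [[-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0]],                       # animal 2: empty slot
    ],
    dtype=np.float32,
)
missing = np.iinfo(np.int64).min  # -9223372036854775808, as printed in ref_coord
ref_coord = np.array([[583, 19], [588, 5], [missing, missing]], dtype=np.int64)

diff = coordinates - ref_coord  # float32 - int64 promotes to float64
# Same masking as the original: negative offsets (empty slots) and huge offsets
# (missing reference keypoints) become NaN, so animal 2 ends up all-NaN.
diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    means = np.nanmean(diff, axis=(1, 2))  # warns "Mean of empty slice" for animal 2

match_id = np.argmin(means)
print(means)     # the entry for animal 2 is nan
print(match_id)  # 2, although animal 1 is clearly the closest
```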

The bug

Many elements of diff are replaced with NaN. If an entire slice diff[i] (i.e., every keypoint and both x/y offsets for animal i) is NaN, the mean of that slice over axes (1, 2) is undefined.

np.nanmean excludes NaN values, but when a slice contains only NaN values there is nothing left to average, which results in:

  • A RuntimeWarning: Mean of empty slice.
  • np.nanmean returning NaN for that slice.

np.argmin is then applied to the result of np.nanmean(diff, axis=(1, 2)), which contains NaN values.
However, np.argmin does not skip NaN values:

  • If the array contains NaN, np.argmin returns the index of the first NaN, as though NaN were the minimum.
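The argmin behaviour can be checked in isolation using the np.nanmean output printed in the debug console (rounded as printed). np.nanargmin is NumPy's NaN-aware counterpart of np.argmin; the snippet compares it with plain np.argmin and with the masked-array approach used in the suggested fix. One difference worth knowing: when every entry is NaN, the masked-array approach silently returns index 0, whereas np.nanargmin raises ValueError.

```python
import numpy as np

# np.nanmean(diff, axis=(1, 2)) as printed in the debug console (10 animals).
means = np.array([506.93161392, 0.61350113] + [np.nan] * 8)

print(np.argmin(means))                        # 2: index of the first NaN
print(np.nanargmin(means))                     # 1: NaN entries are skipped
print(np.argmin(np.ma.masked_invalid(means)))  # 1: NaN entries are masked out

# Edge case: no animal has any valid keypoint.
all_nan = np.full(3, np.nan)
print(np.argmin(np.ma.masked_invalid(all_nan)))  # 0: everything masked, no error raised
try:
    np.nanargmin(all_nan)
except ValueError as err:
    print("nanargmin:", err)                     # nanargmin: All-NaN slice encountered
```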

Suggested fix

Here is my fix, which makes sure the feature of the correct animal is extracted (out of 10 animals in my case).

import numpy as np

def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]

    diff = coordinates - ref_coord
    diff[np.where(np.logical_or(diff > 9000, diff < 0))] = np.nan
    # match_id = np.argmin(np.nanmean(diff, axis=(1, 2)))  # existing code
    # THUAN's fix: mask the NaN means (animals with no valid keypoint) so np.argmin ignores them
    masked_means = np.ma.masked_invalid(np.nanmean(diff, axis=(1, 2)))
    match_id = np.argmin(masked_means)
    return features[match_id]
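A possible variant of the same fix, offered as a sketch rather than the DeepLabCut maintainers' solution: it swaps the masked array for np.nanargmin, so a frame in which no animal has a valid keypoint raises ValueError instead of silently matching animal 0. The float cast and the suppression of the expected "Mean of empty slice" warning are added assumptions, not part of the original code.

```python
import warnings

import numpy as np


def query_feature_by_coord_in_img_space(feature_dict, frame_id, ref_coord):
    features = feature_dict[frame_id]["features"]
    coordinates = feature_dict[frame_id]["coordinates"]

    # Assumption: cast to float so NaN can be assigned even if coordinates are integers.
    diff = np.asarray(coordinates, dtype=float) - ref_coord
    # Offsets below 0 or above 9000 come from empty animal slots or missing keypoints.
    diff[np.logical_or(diff > 9000, diff < 0)] = np.nan

    # Animals without any valid keypoint give an all-NaN slice (mean = NaN);
    # silence the "Mean of empty slice" warning that this is expected to trigger.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean_offsets = np.nanmean(diff, axis=(1, 2))

    # nanargmin skips NaN entries and raises ValueError if every entry is NaN.
    match_id = np.nanargmin(mean_offsets)
    return features[match_id]
```

Raising on the all-NaN case makes a missing match visible to the caller; whether the triplet generation code should catch that error or skip the frame is a separate design decision.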

Results of transformer re-ID training – BEFORE

n_triplets of 1000
Training transformer re-identification model...
Epoch 10, train acc: 0.64
Epoch 10, test acc 0.55
Epoch 20, train acc: 0.64
Epoch 20, test acc 0.55
Epoch 30, train acc: 0.64
Epoch 30, test acc 0.55
Epoch 40, train acc: 0.65
Epoch 40, test acc 0.55
Epoch 50, train acc: 0.64
Epoch 50, test acc 0.54
Epoch 60, train acc: 0.64
Epoch 60, test acc 0.54
Epoch 70, train acc: 0.65
Epoch 70, test acc 0.54
Epoch 80, train acc: 0.65
Epoch 80, test acc 0.54
Epoch 90, train acc: 0.65
Epoch 90, test acc 0.54
Epoch 100, train acc: 0.65
Epoch 100, test acc 0.54
n_triplets of 10000
Training transformer re-identification model...
Epoch 10, train acc: 0.64
Epoch 10, test acc 0.55
Epoch 20, train acc: 0.62
Epoch 20, test acc 0.55
Epoch 30, train acc: 0.63
Epoch 30, test acc 0.55
Epoch 40, train acc: 0.64
Epoch 40, test acc 0.55
Epoch 50, train acc: 0.64
Epoch 50, test acc 0.55
Epoch 60, train acc: 0.64
Epoch 60, test acc 0.55
Epoch 70, train acc: 0.65
Epoch 70, test acc 0.54
Epoch 80, train acc: 0.66
Epoch 80, test acc 0.54
Epoch 90, train acc: 0.64
Epoch 90, test acc 0.54
Epoch 100, train acc: 0.64
Epoch 100, test acc 0.54

Results of transformer re-ID training – AFTER

Training transformer re-identification model...
Epoch 10, train acc: 0.95
Epoch 10, test acc 0.95
Epoch 20, train acc: 0.97
Epoch 20, test acc 0.96
Epoch 30, train acc: 0.97
Epoch 30, test acc 0.96
Epoch 40, train acc: 0.97
Epoch 40, test acc 0.96
Epoch 50, train acc: 0.97
Epoch 50, test acc 0.96
Epoch 60, train acc: 0.97
Epoch 60, test acc 0.96
Epoch 70, train acc: 0.98
Epoch 70, test acc 0.96
Epoch 80, train acc: 0.98
Epoch 80, test acc 0.96
Epoch 90, train acc: 0.97
Epoch 90, test acc 0.96
Epoch 100, train acc: 0.97
Epoch 100, test acc 0.96

Training transformer re-identification model...
Epoch 10, train acc: 0.96
Epoch 10, test acc 0.95
Epoch 20, train acc: 0.99
Epoch 20, test acc 0.98
Epoch 30, train acc: 1.00
Epoch 30, test acc 1.00
Epoch 40, train acc: 1.00
Epoch 40, test acc 1.00
Epoch 50, train acc: 1.00
Epoch 50, test acc 0.99
Epoch 60, train acc: 1.00
Epoch 60, test acc 1.00
Epoch 70, train acc: 1.00
Epoch 70, test acc 1.00
Epoch 80, train acc: 1.00
Epoch 80, test acc 1.00
Epoch 90, train acc: 1.00
Epoch 90, test acc 1.00
Epoch 100, train acc: 1.00
Epoch 100, test acc 1.00

Operating System

operating system rhel

DeepLabCut version

dlc version rc4

DeepLabCut mode

multi animal

Device type

gpu

Steps To Reproduce

please see above

Relevant log output

please see above

Anything else?

please see above

Code of Conduct
