
fix: temporal supervision shift in XYZ loss mask uses t+1 instead of t#52

Open

atharrva01 wants to merge 1 commit into DevoLearn:main from atharrva01:fix/temporal-supervision-shift

Conversation

@atharrva01
Contributor

Description

While digging through the training loop, I found a silent but critical bug in how the XYZ loss mask is constructed in NDP-HNN/train.py.

The mask was using birth_times[c] <= (t + 1), which pulls in cells born at the next time step: cells that literally don't exist in the graph the model is currently processing. So at every snapshot t, the MSE loss was forcing the model's output to match coordinates of cells it hasn't observed yet.

No crash, no NaN; it just silently trains the model on a shifted signal for every cell division event in the dataset, which in a growing C. elegans embryo means essentially every training snapshot.


What was wrong

```python
# Before - mask includes cells born at t+1 (unborn, not in current graph)
mask_next = torch.tensor(
    [birth_times[c] <= (t + 1) for c in cells],
    dtype=torch.bool, device=device
)
```

The model processes snapshot t and produces pred_xyz for cells present at t. But mask_next was selecting rows for cells born at t+1, so the loss was minimizing MSE against coordinates the model had no way to observe or predict causally.
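To make the leak concrete, here is a toy sketch with made-up birth times (the real values come from the lineage data loaded in NDP-HNN/train.py; the cell names below are just illustrative):

```python
# Hypothetical birth times: ABa and ABp appear at snapshot t=2,
# so at t=1 they do not exist in the graph yet.
birth_times = {"AB": 0, "P1": 0, "ABa": 2, "ABp": 2}
cells = ["AB", "P1", "ABa", "ABp"]
t = 1  # current snapshot

# Buggy mask: also selects cells born at t+1 (unborn at snapshot t)
mask_buggy = [birth_times[c] <= (t + 1) for c in cells]

# Fixed mask: only cells alive at the current snapshot
mask_fixed = [birth_times[c] <= t for c in cells]

print(mask_buggy)  # [True, True, True, True]   <- ABa/ABp leak in
print(mask_fixed)  # [True, True, False, False]
```

The same boolean lists are what get wrapped in `torch.tensor(..., dtype=torch.bool)` in the training loop; the buggy variant silently adds two rows of future coordinates to the MSE target.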


Fix

```python
# After - mask selects only cells alive at current snapshot t
mask_next = torch.tensor(
    [birth_times[c] <= t for c in cells],
    dtype=torch.bool, device=device
)
```

This is a single-character change. Everything else (the target_xyz indexing, the MSE computation, incidence_bce, the detach pattern) is unchanged and correct.
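A small regression test could guard against reintroducing the shift. This is a hypothetical sketch (the helper name and data shapes are assumed, not from the repo); it checks the invariant directly: no cell selected by the mask may be born after the snapshot being processed.

```python
def build_mask(birth_times, cells, t):
    # mirrors the fixed expression: birth_times[c] <= t
    return [birth_times[c] <= t for c in cells]

def test_mask_excludes_unborn_cells():
    # toy lineage: cell "c" is born at snapshot 3
    birth_times = {"a": 0, "b": 1, "c": 3}
    cells = ["a", "b", "c"]
    for t in range(4):
        mask = build_mask(birth_times, cells, t)
        for cell, selected in zip(cells, mask):
            # invariant: a selected cell must already exist at snapshot t
            assert not (selected and birth_times[cell] > t)

test_mask_excludes_unborn_cells()
```

Running the same invariant against the old `<= (t + 1)` expression fails at every snapshot immediately preceding a division, which is exactly the leak described above.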


Impact

The model was accumulating RNN gradients from a supervision signal that was off by one time step across the entire training run. After the fix, the learned developmental program actually corresponds to the real embryonic state at each snapshot. Before this, results would show the model appearing to "anticipate" future cell divisions, which looked biologically interesting but was entirely an artifact of the leaky mask.

Use birth_times[c] <= t instead of <= (t+1) so the XYZ loss
supervises predictions against cells alive at the current snapshot,
not one step ahead.

Signed-off-by: atharrva01 <atharvaborade568@gmail.com>
@atharrva01
Contributor Author

hi @devoworm, this PR fixes an off-by-one in the XYZ loss mask: it was including t+1 cells (unborn at the current snapshot) in the MSE supervision, silently training the model on future targets it can't observe.

@atharrva01 atharrva01 marked this pull request as ready for review March 25, 2026 09:20