
generic loss #163

Merged
lazarusA merged 6 commits into main from la/generic_loss
Oct 6, 2025
Conversation


@lazarusA lazarusA commented Oct 6, 2025

closes #162


lazarusA commented Oct 6, 2025

With this change, the following should now work.

Define your loss:

```julia
using Statistics: mean

function pinball(ŷ, y, τ)
    r = ŷ .- y
    ρ = τ .* max.(r, 0) .+ (τ - 1) .* min.(r, 0)
    return mean(ρ)
end
```

and then pass it with positional arguments:

```julia
train(...;
    loss_types=[:mse, :r2, (pinball, (0.1,))],
    training_loss=(pinball, (0.1,)),
)
```

Note the trailing comma in `(0.1,)`: in Julia, `(0.1)` is just a scalar, while `(0.1,)` is a one-element tuple.
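As a sanity check on the loss definition above, the same pinball (quantile) math can be sketched outside Julia. This is a hypothetical plain-Python illustration, not part of the PR:

```python
def pinball(y_hat, y, tau):
    """Pinball (quantile) loss, averaged over elements.

    Mirrors the Julia definition: tau * max(r, 0) + (tau - 1) * min(r, 0),
    with residuals r = y_hat - y.
    """
    r = [p - t for p, t in zip(y_hat, y)]
    rho = [tau * max(ri, 0.0) + (tau - 1.0) * min(ri, 0.0) for ri in r]
    return sum(rho) / len(rho)

# Over-prediction is penalized with weight tau, under-prediction with (1 - tau):
print(pinball([1.0, 2.0], [0.0, 3.0], 0.1))  # ≈ 0.5
```

With `τ = 0.5` this reduces to half the mean absolute error; smaller `τ` penalizes under-prediction more heavily, which is the point of a quantile loss.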

Passing keyword arguments should also work, namely:

```julia
function pinball(ŷ, y; τ=0.2)
    r = ŷ .- y
    ρ = τ .* max.(r, 0) .+ (τ - 1) .* min.(r, 0)
    return mean(ρ)
end
```

and then

```julia
train(...;
    loss_types=[:mse, :r2, (pinball, (τ=0.2,))],
    training_loss=(pinball, (τ=0.2,)),
)
```

Here `(τ=0.2,)` is a one-element NamedTuple, which is splatted into the loss as keyword arguments.
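The two calling conventions above suggest a simple dispatch rule: a plain tuple of extras is passed positionally, while a named collection is passed as keywords. A rough Python sketch of that idea (hypothetical helper names, not the package's actual internals; a `dict` stands in for a Julia NamedTuple):

```python
def apply_loss(spec, y_hat, y):
    """Apply a loss spec of the form (fn, extras).

    extras as a tuple -> forwarded positionally; extras as a dict
    (standing in for a Julia NamedTuple) -> forwarded as keywords.
    Hypothetical sketch of the dispatch, not the real implementation.
    """
    fn, extras = spec
    if isinstance(extras, dict):
        return fn(y_hat, y, **extras)
    return fn(y_hat, y, *extras)

def pinball(y_hat, y, tau=0.2):
    # same pinball loss as above, accepting tau either way
    r = [p - t for p, t in zip(y_hat, y)]
    rho = [tau * max(ri, 0.0) + (tau - 1.0) * min(ri, 0.0) for ri in r]
    return sum(rho) / len(rho)

# Both spellings reach the same loss value:
a = apply_loss((pinball, (0.2,)), [2.0, 1.0], [0.0, 3.0])
b = apply_loss((pinball, {"tau": 0.2}), [2.0, 1.0], [0.0, 3.0])
```

Under this rule, user-defined losses only need to follow the `(ŷ, y, extras...)` calling convention to plug into `loss_types` and `training_loss`.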

@lazarusA lazarusA merged commit cf33df1 into main Oct 6, 2025
4 checks passed
@lazarusA lazarusA deleted the la/generic_loss branch January 27, 2026 08:41

Development

Successfully merging this pull request may close these issues.

do a more generic loss function

2 participants