JAX Attention Refactoring? #1992

@Lime-Cakes

Description

At the moment, the PyTorch versions of most attention classes have been refactored to use attention processors (except for a few that are being deprecated). Are there plans to do the same for the Flax versions?

Since AttentionBlock is deprecated, should its Flax counterpart, FlaxAttentionBlock, be deprecated as well?

class AttentionBlock(nn.Module):

class FlaxAttentionBlock(nn.Module):
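
To make the question concrete, here is a rough sketch of what a processor-style split could look like on the JAX side. It is not the diffusers API: the function names (`default_attn_processor`, `chunked_attn_processor`, `attention`), the processor signature, and the parameter keys are all hypothetical, and it uses plain `jax`/`jax.numpy` rather than Flax modules to stay self-contained. The idea is only that the layer owns the projections while the core attention computation is a swappable callable.

```python
from typing import Callable

import jax
import jax.numpy as jnp

# Hypothetical processor signature: (query, key, value) -> attention output.
# All arrays have shape (batch, heads, seq_len, head_dim).
AttnProcessor = Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]


def default_attn_processor(query, key, value):
    """Plain scaled dot-product attention."""
    scale = query.shape[-1] ** -0.5
    scores = jnp.einsum("bhqd,bhkd->bhqk", query, key) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, value)


def chunked_attn_processor(query, key, value, chunk_size=2):
    """Same result, computed one block of queries at a time."""
    chunks = [
        default_attn_processor(query[:, :, i : i + chunk_size], key, value)
        for i in range(0, query.shape[2], chunk_size)
    ]
    return jnp.concatenate(chunks, axis=2)


def attention(params, hidden_states, num_heads, processor=default_attn_processor):
    """Attention layer that delegates the core computation to `processor`."""
    batch, seq_len, dim = hidden_states.shape
    head_dim = dim // num_heads

    def split_heads(x):
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        return x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(hidden_states @ params["to_q"])
    k = split_heads(hidden_states @ params["to_k"])
    v = split_heads(hidden_states @ params["to_v"])
    out = processor(q, k, v)
    # (batch, heads, seq, head_dim) -> (batch, seq, dim)
    out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, dim)
    return out @ params["to_out"]


# Illustrative random weights; the key names are assumptions, not a real checkpoint layout.
dim, num_heads = 8, 2
k_x, k_q, k_k, k_v, k_o = jax.random.split(jax.random.PRNGKey(0), 5)
params = {
    "to_q": jax.random.normal(k_q, (dim, dim)) * 0.1,
    "to_k": jax.random.normal(k_k, (dim, dim)) * 0.1,
    "to_v": jax.random.normal(k_v, (dim, dim)) * 0.1,
    "to_out": jax.random.normal(k_o, (dim, dim)) * 0.1,
}
x = jax.random.normal(k_x, (1, 4, dim))

out_default = attention(params, x, num_heads)
out_chunked = attention(params, x, num_heads, processor=chunked_attn_processor)
```

With something like this, swapping in a different attention implementation would only mean passing a different `processor`, without touching the projection weights.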
