🚀 The feature, motivation and pitch
Currently, the forward method of EmbeddingBag supports only 1D inputs when offsets are passed. Hence, training / inference on mini-batches of data isn't supported with offsets.
Offsets are very useful when training on tabular datasets with "multi-valued" cells, such as movie genres, since we may want to sum / average the embeddings associated with several genres into a single vector. There can also be weighted multi-valued cells, for example, when the multiple values are generated by an auxiliary model and the weights represent the model's confidence in its predictions. For example, consider automatic extraction of movie genres from titles and descriptions.
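As a minimal sketch of the current behavior, the 1D input + offsets form works for a single flattened batch of variable-length "bags" (the genre indices and confidence weights below are made-up illustration data):

```python
import torch
import torch.nn as nn

# EmbeddingBag today: a 1D flattened index tensor plus offsets marking where
# each bag starts. Example: three movies with 2, 1, and 3 genre ids each.
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
genres = torch.tensor([1, 3, 2, 0, 4, 5])                 # flattened genre ids
offsets = torch.tensor([0, 2, 3])                          # bag start positions
conf = torch.tensor([0.9, 0.1, 1.0, 0.5, 0.3, 0.2])        # per-genre weights
out = bag(genres, offsets, per_sample_weights=conf)        # requires mode="sum"
print(out.shape)  # torch.Size([3, 4]) — one aggregated vector per movie
```

Note that `per_sample_weights` is only supported with `mode="sum"`, which is exactly the weighted aggregation described above; the limitation is that this flattened layout cannot carry an extra mini-batch dimension.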
Alternatives
Two possible alternatives:
- Use a regular `torch.nn.Embedding`, extract the embedding vectors, multiply them by the weights manually, and aggregate them. In this case we lose the efficiency of the `EmbeddingBag` class, which doesn't have to materialize the full embedding tensor. This idea is relevant only if the number of features in each mini-batch item is the same.
- Use an `EmbeddingBag` in the model, decompose the mini-batch into its constituent items, and compute the model's output for each item in a for-loop.
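The first alternative can be sketched as follows, assuming every mini-batch item has the same number of features (the index and weight tensors are illustrative):

```python
import torch
import torch.nn as nn

# Workaround: when each item has a fixed number of features, use nn.Embedding
# and reduce manually. Unlike EmbeddingBag, this materializes the full
# (batch, num_feats, dim) tensor before the weighted sum.
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
idx = torch.tensor([[1, 3], [2, 0], [4, 5]])            # (batch=3, feats=2)
w = torch.tensor([[0.9, 0.1], [1.0, 0.0], [0.5, 0.5]])  # per-feature weights
out = (emb(idx) * w.unsqueeze(-1)).sum(dim=1)           # (3, 4)
print(out.shape)  # torch.Size([3, 4])
```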
Additional context
No response
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @mikaylagawarecki