
Stale References in RNNBase and Correct Device in Embedding #651

@bmccann

Description


In RNNBase, self.all_weights holds references to all of the module's parameters. When the module is wrapped in DataParallel, replicate replaces those parameters on each replica, but the stale references in self.all_weights are still the ones passed to cuDNN.
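A minimal pure-Python sketch of the stale-reference problem, assuming a simplified model of replication. Param, RNNSketch, and replicate below are hypothetical stand-ins, not PyTorch's classes or its actual replicate implementation, and devices are plain strings. A list cached once keeps pointing at the original parameters after the replica rebinds its attributes, whereas a list rebuilt on each access sees the replacements:

```python
import copy


class Param:
    """Hypothetical stand-in for a parameter tensor (not a PyTorch class)."""

    def __init__(self, name, device):
        self.name = name
        self.device = device


class RNNSketch:
    """Hypothetical module that exposes its weights through all_weights."""

    weight_names = ("weight_ih", "weight_hh")

    def __init__(self, device="cuda:0", cache_refs=True):
        for name in self.weight_names:
            setattr(self, name, Param(name, device))
        self.cache_refs = cache_refs
        # Snapshot of the references, taken once at construction time.
        self._cached = [getattr(self, n) for n in self.weight_names]

    @property
    def all_weights(self):
        if self.cache_refs:
            # Stale after replication: still the original Param objects.
            return self._cached
        # Always current: looks up whatever the attributes point to now.
        return [getattr(self, n) for n in self.weight_names]


def replicate(module, device):
    """Hypothetical, simplified replicate: shallow-copy the module and
    rebind each parameter attribute to a new Param on the target device."""
    replica = copy.copy(module)  # shares the same _cached list object
    for name in module.weight_names:
        setattr(replica, name, Param(name, device))
    return replica


stale = replicate(RNNSketch(cache_refs=True), "cuda:1")
fresh = replicate(RNNSketch(cache_refs=False), "cuda:1")
print([w.device for w in stale.all_weights])  # ['cuda:0', 'cuda:0']
print([w.device for w in fresh.all_weights])  # ['cuda:1', 'cuda:1']
```

Rebuilding the list on access (or refreshing it after replication) is one possible fix; the issue itself does not prescribe a specific one.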

In thnn/sparse.py, Embedding's backward should create grad_weight, _indices, _counts, and _sorted on the same device as grad_output.
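A minimal sketch of device-consistent allocation in an embedding backward, assuming PyTorch's tensor API (Tensor.new_zeros, Tensor.index_add_). embedding_backward_sketch is a hypothetical helper, not the code in thnn/sparse.py, and it only builds grad_weight; the other buffers listed above would follow the same rule of taking their device from grad_output:

```python
import torch


def embedding_backward_sketch(grad_output, indices, num_weights):
    """Hypothetical sketch: accumulate grad_output rows into grad_weight
    (no padding_idx or scale_grad_by_freq handling)."""
    embedding_dim = grad_output.size(-1)
    # new_zeros inherits grad_output's dtype and device, so grad_weight is
    # allocated on the same GPU as grad_output, not on the current device.
    grad_weight = grad_output.new_zeros(num_weights, embedding_dim)
    # Index tensors must live on the same device as the tensors they index.
    flat_indices = indices.reshape(-1).to(device=grad_output.device, dtype=torch.long)
    flat_grad = grad_output.reshape(-1, embedding_dim)
    # Sum the gradient of every lookup into its embedding row.
    grad_weight.index_add_(0, flat_indices, flat_grad)
    return grad_weight


device = "cuda" if torch.cuda.is_available() else "cpu"
indices = torch.tensor([[0, 2], [2, 1]])
grad_output = torch.ones(2, 2, 3, device=device)
grad_weight = embedding_backward_sketch(grad_output, indices, num_weights=4)
print(grad_weight.device == grad_output.device)  # True
print(grad_weight[:, 0].tolist())  # [1.0, 1.0, 2.0, 0.0]
```

Passing device=grad_output.device to a factory such as torch.zeros would work equally well; new_zeros is used here because it also carries over the dtype.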
