This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Cannot print parameter summary of embedding layer  #12778

@haven-jeon

Description

The Embedding layer is reported with 0 parameters when the model is printed with `mx.viz.print_summary()`.
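For reference, a Gluon `nn.Embedding` layer owns a single weight matrix of shape `(input_dim, output_dim)` and no bias, so its row in the summary would be expected to show `input_dim * output_dim` rather than 0. A minimal sketch of that arithmetic; the helper name `embedding_param_count` and the vocabulary size of 10000 are hypothetical, since the issue does not state the real size of `vocab`:

```python
def embedding_param_count(input_dim, output_dim):
    """Parameters in an embedding lookup table: one output_dim-wide row per token id, no bias."""
    return input_dim * output_dim

# Hypothetical vocabulary size; the issue only fixes num_embed=50
print(embedding_param_count(10000, 50))
```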

Environment info (Required)

MXNet 1.3.0

Package used (Python/R/Scala/Julia):
Python

Minimum reproducible example

import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class SentClassificationModel(gluon.HybridBlock):
    def __init__(self, vocab_size, num_embed, **kwargs):
        super(SentClassificationModel, self).__init__(**kwargs)
        with self.name_scope():
            # Embedding weight has shape (vocab_size, num_embed)
            self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)
            self.drop = nn.Dropout(0.3)
            self.fc = nn.Dense(100, activation='relu')
            self.out = nn.Dense(2)

    def hybrid_forward(self, F, inputs):
        em_out = self.drop(self.embed(inputs))
        fc_out = self.fc(em_out)
        return self.out(fc_out)

ctx = mx.gpu()

# `vocab` is built elsewhere in the reporter's code (not shown in the issue)
model = SentClassificationModel(vocab_size=len(vocab.idx_to_token), num_embed=50)

model.initialize(mx.init.Xavier(), ctx=ctx)
model.hybridize()

# Calling the block on a Symbol returns the symbolic graph that print_summary inspects
mx.viz.print_summary(
    model(mx.sym.var('data')),
    shape={'data': (1, 30)},  # set your shape here
)
________________________________________________________________________________________________________________________
Layer (type)                                        Output Shape            Param #     Previous Layer                  
========================================================================================================================
data(null)                                          30                      0                                           
________________________________________________________________________________________________________________________
sentclassificationmodel0_embedding0_fwd(Embedding)  30x50                   0           data                            
________________________________________________________________________________________________________________________
sentclassificationmodel0_dropout0_fwd(Dropout)      30x50                   0           sentclassificationmodel0_embeddi
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense0_fwd(FullyConnected) 100                     3100        sentclassificationmodel0_dropout
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense0_relu_fwd(Activation)100                     0           sentclassificationmodel0_dense0_
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense1_fwd(FullyConnected) 2                       202         sentclassificationmodel0_dense0_
========================================================================================================================
Total params: 3302
________________________________________________________________________________________________________________________
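Recomputing the expected counts by hand makes the discrepancy concrete. The sketch below assumes Gluon's defaults for `nn.Dense` (`use_bias=True`, `flatten=True`) and a hypothetical vocabulary of 10000 tokens, since the issue does not give `len(vocab.idx_to_token)`; `dense_param_count` is an illustrative helper, not an MXNet API.

```python
from math import prod  # Python 3.8+

def dense_param_count(in_shape, units, use_bias=True, flatten=True):
    """Dense/FullyConnected parameters; in_shape excludes the batch axis.
    With flatten=True (Gluon's default) all non-batch dims are collapsed into in_units."""
    in_units = prod(in_shape) if flatten else in_shape[-1]
    return in_units * units + (units if use_bias else 0)

vocab_size = 10000  # hypothetical; the real size of `vocab` is not given in the issue
seq_len, num_embed = 30, 50

expected = {
    "embedding0": vocab_size * num_embed,                        # weight only, no bias
    "dense0": dense_param_count((seq_len, num_embed), 100),      # sees the (30, 50) embedding output
    "dense1": dense_param_count((100,), 2),                      # sees the 100-unit output of dense0
}
for name, n in expected.items():
    print(f"{name:12s}{n}")
print("total", sum(expected.values()))
```

Under these assumptions the embedding row would read 500000 instead of 0. The same arithmetic also suggests the `dense0` row is undercounted: 3100 equals (30 + 1) × 100, which matches using only the first non-batch dimension of the (30, 50) input, whereas the flattened layer holds (1500 + 1) × 100 = 150100 parameters.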
