This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Cannot print parameter summary of embedding layer #12778
Closed
Description
The Embedding layer is reported with 0 parameters when calling mx.viz.print_summary().
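For reference, Gluon's nn.Embedding holds a single weight matrix of shape (input_dim, output_dim), so the summary would be expected to show input_dim × output_dim parameters for that row rather than 0. A minimal sketch of the arithmetic, where vocab_size is a hypothetical stand-in for len(vocab.idx_to_token):

```python
def embedding_param_count(input_dim: int, output_dim: int) -> int:
    """Trainable values in an embedding table of shape (input_dim, output_dim)."""
    # One output_dim-long vector is stored per token id.
    return input_dim * output_dim


vocab_size = 10_000  # hypothetical; the report uses len(vocab.idx_to_token)
num_embed = 50       # same value as num_embed=50 in the report
print(embedding_param_count(vocab_size, num_embed))  # 500000
```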
Environment info (Required)
MXNet 1.3.0
Package used (Python/R/Scala/Julia):
I am using Python.
Minimum reproducible example
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

class SentClassificationModel(gluon.HybridBlock):
    def __init__(self, vocab_size, num_embed, **kwargs):
        super(SentClassificationModel, self).__init__(**kwargs)
        with self.name_scope():
            self.embed = nn.Embedding(input_dim=vocab_size, output_dim=num_embed)
            self.drop = nn.Dropout(0.3)
            self.fc = nn.Dense(100, activation='relu')
            self.out = nn.Dense(2)

    def hybrid_forward(self, F, inputs):
        # Embed token ids, apply dropout, then two dense layers
        em_out = self.drop(self.embed(inputs))
        fc_out = self.fc(em_out)
        return self.out(fc_out)

ctx = mx.gpu()
# `vocab` is the token vocabulary built earlier (not shown in this report)
model = SentClassificationModel(vocab_size=len(vocab.idx_to_token), num_embed=50)
model.initialize(mx.init.Xavier(), ctx=ctx)
model.hybridize()

# Calling the block on a Symbol returns the graph that print_summary inspects
mx.viz.print_summary(
    model(mx.sym.var('data')),
    shape={'data': (1, 30)},  # set your shape here
)
________________________________________________________________________________________________________________________
Layer (type)                                        Output Shape            Param #     Previous Layer
========================================================================================================================
data(null)                                          30                      0
________________________________________________________________________________________________________________________
sentclassificationmodel0_embedding0_fwd(Embedding)  30x50                   0           data
________________________________________________________________________________________________________________________
sentclassificationmodel0_dropout0_fwd(Dropout)      30x50                   0           sentclassificationmodel0_embeddi
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense0_fwd(FullyConnected) 100                     3100        sentclassificationmodel0_dropout
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense0_relu_fwd(Activation)100                     0           sentclassificationmodel0_dense0_
________________________________________________________________________________________________________________________
sentclassificationmodel0_dense1_fwd(FullyConnected) 2                       202         sentclassificationmodel0_dense0_
========================================================================================================================
Total params: 3302
________________________________________________________________________________________________________________________
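The numbers in the table can be reproduced with a simplified per-layer counter. This is a sketch under stated assumptions, not MXNet's implementation: the node list is a hand-written stand-in for the hybridized graph (not MXNet's JSON schema), VOCAB_SIZE is hypothetical, and the FullyConnected rule, (first non-batch input dimension + 1) × num_hidden, is inferred from the 3100 and 202 rows rather than taken from MXNet's source. Any operator without a rule contributes 0, which matches what the summary shows for the Embedding row.

```python
# Simplified, hypothetical re-count of the print_summary rows in this issue.
# Each node is a hand-written stand-in for a graph node, not MXNet's JSON schema.

VOCAB_SIZE = 10_000  # hypothetical; the issue uses len(vocab.idx_to_token)

# (name, op, attrs, input shape without the batch axis)
NODES = [
    ("embedding0", "Embedding", {"input_dim": VOCAB_SIZE, "output_dim": 50}, (30,)),
    ("dropout0", "Dropout", {}, (30, 50)),
    ("dense0", "FullyConnected", {"num_hidden": 100}, (30, 50)),
    ("dense0_relu", "Activation", {}, (100,)),
    ("dense1", "FullyConnected", {"num_hidden": 2}, (100,)),
]


def count_params(op, attrs, in_shape, count_embedding):
    """Return the parameter count for one node under the inferred rules."""
    if op == "FullyConnected":
        # Inferred from the table: (first input dim + 1 bias) * num_hidden.
        return (in_shape[0] + 1) * attrs["num_hidden"]
    if op == "Embedding" and count_embedding:
        # One output_dim-long vector per vocabulary entry.
        return attrs["input_dim"] * attrs["output_dim"]
    # Operators without an explicit rule are reported as 0.
    return 0


def summarize(count_embedding):
    """Map each node name to its parameter count."""
    return {
        name: count_params(op, attrs, shape, count_embedding)
        for name, op, attrs, shape in NODES
    }


before = summarize(count_embedding=False)
after = summarize(count_embedding=True)
print(sum(before.values()))  # 3302, the "Total params" shown above
print(after["embedding0"], sum(after.values()))  # 500000 503302
```

With an Embedding rule in place, that row becomes input_dim × output_dim. The same inferred rule also accounts for the dense0 row: assuming nn.Dense keeps its default flatten=True, that layer actually receives 30 × 50 = 1500 features and holds (1500 + 1) × 100 = 150,100 parameters, whereas 3100 corresponds to using only the first input dimension, 30.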