This repository was archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.7k
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Some mxnet ctc_loss bug & feature request #10995
Copy link
Copy link
Closed
Labels
Description
MXNet's ctc_loss has nearly the same source code as Baidu's warp-ctc, with only minor modifications, but it has some bugs.
# Shared setup for all four test cases below.
import mxnet as mx
import numpy as np
import numpy.random as npr

# Case 1 - mxnet ctc_loss is all right
# Case 1: small alphabet, float (default-dtype) labels -> ctc_loss works.
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60

# Activations: (seq_len, batch, num_classes), uniform random, on GPU 0.
x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
# Random integer labels in [0, num_classes).
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0))  # float label type (mx.nd default dtype)
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

# Case 2 - mxnet ctc_loss cannot support integer label types
# Case 2: identical to case 1 except the labels are int32 —
# this is the reported failure: ctc_loss rejects integer label dtypes.
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)  # integer label type
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

# Case 3 - mxnet ctc_loss is slow or will crash when num_classes is big
# Case 3: large alphabet (num_classes = 6000) — reported to make ctc_loss
# slow or crash, while WarpCTC on the same data (below) is fine.
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)
loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

# WarpCTC expects flattened inputs: data (seq_len*batch, num_classes),
# label (batch*label_len,).
x = mx.nd.Reshape(x, shape=(-3, -2))
Y = mx.nd.Reshape(Y, shape=(-1,))
loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)
print(loss)

# Case 4 - warpctc is all OK with big num_classes and integer types
# Case 4: WarpCTC handles a large alphabet and int32 labels without issue.
batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000

# Activations on GPU 0: (seq_len, batch, num_classes).
acts = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
# Random int32 labels in [0, num_classes).
labels_np = npr.randint(0, num_classes, size=(batch_size, label_len))
labels = mx.nd.array(labels_np, ctx=mx.gpu(0), dtype=np.int32)

# Flatten to the layout WarpCTC expects: data (seq_len*batch, num_classes),
# labels (batch*label_len,).
acts = mx.nd.Reshape(acts, shape=(-3, -2))
labels = mx.nd.Reshape(labels, shape=(-1,))
loss = mx.nd.WarpCTC(data=acts, label=labels, label_length=label_len, input_length=seq_len)
print(loss)