This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Some mxnet ctc_loss bug & feature request #10995

@chinakook

Description

MXNet's ctc_loss shares nearly all of its source code with Baidu's warp-ctc, with only small modifications, but it has several bugs.
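For context, the quantity both kernels compute is the standard CTC negative log-likelihood. A minimal pure-NumPy sketch of the forward pass for a single sequence is below (assumptions: blank index 0 as in warp-ctc, and `ctc_loss_ref` is a hypothetical name, not part of either library):

```python
import numpy as np

def ctc_loss_ref(probs, label, blank=0):
    # probs: (T, C) softmax outputs for one sequence
    # label: sequence of target symbols, blanks excluded
    # Returns -log p(label | probs), the CTC loss (assumed blank index 0).
    T, C = probs.shape
    # Extended label: blanks interleaved around every symbol
    ext = [blank]
    for s in label:
        ext += [int(s), blank]
    S = len(ext)
    alpha = np.zeros((T, S))
    alpha[0, 0] = probs[0, blank]
    if S > 1:
        alpha[0, 1] = probs[0, ext[1]]
    for t in range(1, T):
        for s in range(S):
            a = alpha[t - 1, s]                      # stay on the same state
            if s > 0:
                a += alpha[t - 1, s - 1]             # advance by one state
            if s > 1 and ext[s] != blank and ext[s] != ext[s - 2]:
                a += alpha[t - 1, s - 2]             # skip a blank between distinct symbols
            alpha[t, s] = a * probs[t, ext[s]]
    # Valid endings: final symbol or trailing blank
    p = alpha[T - 1, S - 1] + (alpha[T - 1, S - 2] if S > 1 else 0.0)
    return -np.log(p)
```

Production kernels do this recursion in log space for stability; the sketch keeps raw probabilities only for readability.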

import mxnet as mx
import numpy as np
import numpy.random as npr

Case 1 - mxnet ctc_loss works correctly with float labels

batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0)) # float label type

loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

Case 2 - mxnet ctc_loss does not support integer label types

batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 60

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)

loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

Case 3 - mxnet ctc_loss is slow, or crashes, when num_classes is large

batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)

loss = mx.nd.contrib.ctc_loss(data=x, label=Y)
loss = mx.nd.make_loss(loss)
print(loss.asnumpy())

x = mx.nd.Reshape(x, shape=(-3, -2))  # merge (seq_len, batch) axes as WarpCTC expects
Y = mx.nd.Reshape(Y, shape=(-1,))     # flatten labels to 1-D
loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)
print(loss)

Case 4 - warpctc handles large num_classes and integer label types correctly

batch_size = 1024
seq_len = 35
label_len = 10
num_classes = 6000

x = mx.nd.random.uniform(shape=(seq_len, batch_size, num_classes), ctx=mx.gpu(0))
y = npr.randint(0, num_classes, size=(batch_size, label_len))
Y = mx.nd.array(y, ctx=mx.gpu(0), dtype=np.int32)

x = mx.nd.Reshape(x, shape=(-3, -2))  # merge (seq_len, batch) axes as WarpCTC expects
Y = mx.nd.Reshape(Y, shape=(-1,))     # flatten labels to 1-D
loss = mx.nd.WarpCTC(data=x, label=Y, label_length=label_len, input_length=seq_len)
print(loss)
