Conversation

@mansoorcheema
Contributor

Implemented AdaptiveLogSoftmaxWithLoss and added some module tests. Reference: #25883

Contributor

We might be able to do:

Suggested change
- output = output.index_copy_(0, row_indices, local_logprob.squeeze(1));
+ output.index_copy_(0, row_indices, local_logprob.squeeze(1));

Contributor Author

@mansoorcheema mansoorcheema Nov 14, 2019

It was confusing for me as well. As mentioned in another comment, .index_copy_() does not appear to update the calling Tensor, although as an in-place operation it should.
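
For readers following the exchange above, here is a minimal sketch of the convention the suggestion relies on: an in-place method (trailing underscore) mutates its receiver and returns a reference to that same object, so assigning the result back should be redundant. ToyTensor and copy_rows_in_place are hypothetical names written for this illustration only. This is not the LibTorch torch::Tensor type, and it does not reproduce the behavior the author reported.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 1-D tensor; NOT the LibTorch torch::Tensor type.
struct ToyTensor {
  std::vector<float> data;

  // Mimics the in-place convention: write source[i] into data[index[i]],
  // then return a reference to this same object (not a copy).
  ToyTensor& index_copy_(const std::vector<std::size_t>& index,
                         const std::vector<float>& source) {
    assert(index.size() == source.size());
    for (std::size_t i = 0; i < index.size(); ++i) {
      assert(index[i] < data.size());
      data[index[i]] = source[i];
    }
    return *this;
  }
};

// Illustration helper (assumed name): call the in-place method and
// deliberately discard its return value.
ToyTensor copy_rows_in_place() {
  ToyTensor output{{0.f, 0.f, 0.f, 0.f}};
  output.index_copy_({1, 3}, {-0.5f, -2.0f});
  return output;
}
```

Under this convention, the tensor returned by copy_rows_in_place() already holds the copied values at positions 1 and 3, which is why the reviewer's shorter form would normally be enough.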

@mansoorcheema
Contributor Author

mansoorcheema commented Nov 18, 2019

@yf225 Thanks for the feedback. I have committed the suggested changes, except for using the target mask for indexing (multi-dimensional tensor indexing), which is not currently supported in the C++ API. The details are mentioned inline.
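
As a rough illustration of the workaround described above, the sketch below converts a range test on the targets into explicit row indices and then gathers those rows, instead of indexing with a boolean mask directly. It is a toy model on std::vector; rows_in_range, select_rows, and Matrix are hypothetical names, and none of this is the code from the PR or a LibTorch API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy model only: std::vector stands in for tensors. None of these names
// come from the PR or from LibTorch.
using Matrix = std::vector<std::vector<float>>;

// Turn the implicit mask (low <= target[i] < high) into explicit row indices.
std::vector<std::size_t> rows_in_range(const std::vector<long>& target,
                                       long low, long high) {
  std::vector<std::size_t> rows;
  for (std::size_t i = 0; i < target.size(); ++i) {
    if (target[i] >= low && target[i] < high) {
      rows.push_back(i);
    }
  }
  return rows;
}

// Gather the chosen rows, in order, into a new matrix.
Matrix select_rows(const Matrix& input, const std::vector<std::size_t>& rows) {
  Matrix out;
  out.reserve(rows.size());
  for (std::size_t r : rows) {
    assert(r < input.size());
    out.push_back(input[r]);
  }
  return out;
}
```

With targets {0, 5, 2, 7} and the range [4, 8), rows_in_range yields rows 1 and 3, and select_rows returns just those two rows of the input. With mask indexing available, the two steps could presumably collapse into a single indexing expression.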

@mansoorcheema mansoorcheema force-pushed the AdaptiveLogSoftmax branch 2 times, most recently from 2dfb76f to 20d1d54 Compare November 28, 2019 12:44
@mansoorcheema
Contributor Author

@yf225 I have committed a few updates and rebased on the latest code. Could you take a look?

@ezyang ezyang added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Feb 3, 2020
@yf225 yf225 added the module: cpp label (Related to C++ API) Mar 11, 2020
@yf225 yf225 force-pushed the AdaptiveLogSoftmax branch from 183206d to a485104 Compare March 12, 2020 00:11
Contributor

@yf225 yf225 left a comment

Thanks so much for the awesome work! Apologies for the delay; I was waiting for C++ tensor multi-dim indexing to become available, so that we don't need to use narrow / index_select :D

Contributor

@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 yf225 changed the title Adaptive log softmax AdaptiveLogSoftmaxWithLoss Mar 12, 2020
@yf225 yf225 changed the title AdaptiveLogSoftmaxWithLoss [C++ API] AdaptiveLogSoftmaxWithLoss Mar 12, 2020

@dr-ci
dr-ci bot commented Mar 12, 2020

💊 CircleCI build failures summary and remediations

As of commit 521cd95 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_test (1/1)

Step: "Test" (full log | pattern match details) <confirmed not flaky by 2 failures>

Mar 12 03:53:52 ERROR: test_accurracy (__main__.TrainMnist)
Mar 12 03:53:17 2020-03-12 03:53:17.899965: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version 
Mar 12 03:53:17 2020-03-12 03:53:17.904861: I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:300] Initialize GrpcChannelCache for job localservice -> {0 -> localhost:40873} 
Mar 12 03:53:17 2020-03-12 03:53:17.905210: I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:390] Started server with target: grpc://localhost:40873 
Mar 12 03:53:18 2020-03-12 03:53:18.050956: W tensorflow/compiler/jit/xla_device.cc:398] XLA_GPU and XLA_CPU devices are deprecated and will be removed in subsequent releases. Instead, use either @tf.function(experimental_compile=True) for must-compile semantics, or run with TF_XLA_FLAGS=--tf_xla_auto_jit=2 for auto-clustering best-effort compilation. 
Mar 12 03:53:20 Running MNIST Test 
Mar 12 03:53:20 + echo 'Running MNIST Test' 
Mar 12 03:53:20 + python test/test_train_mnist.py --tidy 
Mar 12 03:53:52  0it [00:00, ?it/s]EDownloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist-data/MNIST/raw/train-images-idx3-ubyte.gz 
Mar 12 03:53:52  
Mar 12 03:53:52 ====================================================================== 
Mar 12 03:53:52 ERROR: test_accurracy (__main__.TrainMnist) 
Mar 12 03:53:52 ---------------------------------------------------------------------- 
Mar 12 03:53:52 Traceback (most recent call last): 
Mar 12 03:53:52   File "test/test_train_mnist.py", line 186, in test_accurracy 
Mar 12 03:53:52     self.assertGreaterEqual(train_mnist(), FLAGS.target_accuracy) 
Mar 12 03:53:52   File "test/test_train_mnist.py", line 74, in train_mnist 
Mar 12 03:53:52     transforms.Normalize((0.1307,), (0.3081,))])) 
Mar 12 03:53:52   File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torchvision/datasets/mnist.py", line 70, in __init__ 
Mar 12 03:53:52     self.download() 
Mar 12 03:53:52   File "/var/lib/jenkins/.local/lib/python3.6/site-packages/torchvision/datasets/mnist.py", line 137, in download 
Mar 12 03:53:52     download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) 

This comment has been revised 14 times.

@facebook-github-bot
Contributor

@yf225 merged this pull request in e95657b.

Labels

Merged, module: cpp (Related to C++ API), open source, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

6 participants