
Commit 0728f3b
Author: akropp
Parent: b26b13c

    Use MakeAlias to slice input data

1 file changed: 10 additions, 8 deletions

src/mlpack/methods/ann/ffn_impl.hpp

@@ -494,24 +494,26 @@ typename MatType::elem_type FFN<
   // pass.
   networkOutput.set_size(network.OutputSize(), batchSize);
 
-  network.Forward(MatType(predictors.cols(begin, begin + batchSize - 1), true), networkOutput);
+  // alias the batches so we don't copy memory
+  MatType predictors_batch, responses_batch;
+  MakeAlias(predictors_batch, predictors.colptr(begin), predictors.n_rows, batchSize);
+  MakeAlias(responses_batch, responses.colptr(begin), responses.n_rows, batchSize);
 
-  const typename MatType::elem_type obj = outputLayer.Forward(networkOutput,
-      MatType(responses.cols(begin, begin + batchSize - 1), true)) + network.Loss();
+  network.Forward(predictors_batch, networkOutput);
+
+  const typename MatType::elem_type obj = outputLayer.Forward(networkOutput, responses_batch) + network.Loss();
 
   // Now perform the backward pass.
-  outputLayer.Backward(networkOutput,
-      MatType(responses.cols(begin, begin + batchSize - 1), true), error);
+  outputLayer.Backward(networkOutput, responses_batch, error);
 
   // The delta should have the same size as the input.
   networkDelta.set_size(predictors.n_rows, batchSize);
-  network.Backward(MatType(predictors.cols(begin, begin + batchSize - 1), true), error, networkDelta);
+  network.Backward(predictors_batch, error, networkDelta);
 
   // Now compute the gradients.
   // The gradient should have the same size as the parameters.
   gradient.set_size(parameters.n_rows, parameters.n_cols);
-  network.Gradient(MatType(predictors.cols(begin, begin + batchSize - 1), true), error,
-      gradient);
+  network.Gradient(predictors_batch, error, gradient);
 
   return obj;
 }
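
Note on the change: the old code built a fresh MatType from predictors.cols(...) with the copy flag set to true, so every forward/backward/gradient call copied the batch. MakeAlias instead wraps the existing column memory, which is possible because Armadillo stores matrices column-major, so a run of whole columns is contiguous. Below is a minimal sketch of the underlying idea using Armadillo's advanced matrix constructor directly; it is illustrative only, not mlpack's actual MakeAlias implementation, and the variable names are made up for the example.

// Sketch: wrapping a contiguous block of columns in an arma::mat without
// copying. This is the effect an alias has for dense Armadillo types;
// it is not mlpack's code.
#include <armadillo>

int main()
{
  arma::mat predictors(3, 10, arma::fill::randu);

  const size_t begin = 2, batchSize = 4;

  // Copying slice: allocates a new 3x4 matrix and copies the data.
  arma::mat copied(predictors.cols(begin, begin + batchSize - 1));

  // Alias: reuses the memory of columns [begin, begin + batchSize - 1].
  // copy_aux_mem = false avoids the copy; strict = true prevents the
  // alias from being resized away from the borrowed memory.
  arma::mat batchAlias(predictors.colptr(begin), predictors.n_rows,
      batchSize, /* copy_aux_mem */ false, /* strict */ true);

  // Writes through the alias modify `predictors` in place.
  batchAlias.zeros();
  predictors.print("predictors after zeroing the aliased batch:");

  return 0;
}

This only works for whole-column slices addressed via colptr(begin); an arbitrary row/column submatrix is not contiguous in memory and would still require a copy.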
