[RFC] Memory format (aka layout aka NHWC) support #19092

@dzhulgakov

Problem statement

CNN operators utilize a canonical order of tensor dimensions and assign them semantic meaning. For the 2D case in PyTorch today, an input to torch.nn.Conv2d has to be a 4D tensor in NCHW order: <batch, channels, height, width>.

For performance reasons, it's often beneficial to reorder dimensions so that the memory accessed by particular operations is laid out contiguously and locality is better utilized. The most common option is moving the channels dimension to the end: NHWC. There can be even more complex memory formats that tile one dimension into blocks, e.g. <N, C/16, H, W, C16>.
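
To make the difference concrete, here is a minimal pure-Python sketch that maps the same logical (n, c, h, w) index to a linear memory offset under NCHW, NHWC and the blocked <N, C/16, H, W, C16> format. The helper names offset_nchw, offset_nhwc and offset_blocked are illustrative and are not PyTorch APIs. The sketch assumes a dense buffer and, for the blocked case, that C is a multiple of the block size.

```python
# Illustrative helpers (not PyTorch APIs): linear memory offset of a logical
# (n, c, h, w) index for a dense tensor with logical sizes (N, C, H, W).

def offset_nchw(idx, sizes):
    n, c, h, w = idx
    _, C, H, W = sizes
    return ((n * C + c) * H + h) * W + w

def offset_nhwc(idx, sizes):
    n, c, h, w = idx
    _, C, H, W = sizes
    return ((n * H + h) * W + w) * C + c

def offset_blocked(idx, sizes, block=16):
    # <N, C/block, H, W, C_block>: channels are split into blocks of `block`
    # and the within-block channel index is the innermost dimension.
    n, c, h, w = idx
    _, C, H, W = sizes
    assert C % block == 0, "this sketch assumes C is a multiple of the block size"
    cb, ci = divmod(c, block)  # which block, and position inside the block
    return (((n * (C // block) + cb) * H + h) * W + w) * block + ci
```

For sizes (1, 32, 2, 2), the element at logical index (0, 17, 1, 0) lands at offset 70 in NCHW, 81 in NHWC and 97 in the blocked format. Its logical index never changes.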

Example libraries utilizing it include:

  • cuDNN has faster performance on Volta in NHWC
  • fbgemm and qnnpack don't support NCHW
  • libxsmm does support NCHW, but the performance penalty is something like 50% (IIRC)

The challenge is that transforming the dimension order itself is expensive. So when multiple CNN operations are performed in a row (e.g. conv(relu(conv))), it's beneficial to transform to the different memory format once, carry out the operations and then reorder back.

Thus it's important to make PyTorch aware of different dimension orders and able to pass tensors with different memory formats between operations in both eager and JIT mode. Furthermore, it's beneficial to have automatic JIT optimization passes that apply heuristics or search techniques to figure out whether changing the memory format is beneficial perf-wise and where in the model it makes sense to do it.

We strive to build an API capable of representing:

  • Tensors with different memory formats (at the beginning, just dimension orders) in PyTorch, in both eager mode and JIT. Blocked layouts are lower priority but still nice to have.
  • User-exposed APIs for querying and changing memory format
  • Core CNN operations that can handle input tensors in different memory formats and route to the corresponding faster implementation
  • The ability to infer and optimize memory formats in JIT passes

Terminology: the problem above is often referred to as “layout” (mxnet), “data_format” (tf), “image_format” (keras), “order” (caffe2). We propose to use the name “memory format” or “memory_format” in PyTorch. The name “layout” is unfortunately already taken in PyTorch, with values 'strided' vs 'sparse_coo', so that naming option is not available.

Affected operators

The following operators, at minimum, should be memory-format-aware. In addition to producing the correct result, they need to deliver the best performance from underlying libraries AND preserve the memory format of outputs in order to propagate explicitly specified user intent.

  • convolution
  • different kinds of pooling
  • batch norm, layer norm, instance norm (generally, all normalization layers)
  • upsampling/interpolation
  • feature dropout
  • softmax to a lesser degree - the dimension can be specified manually there, but efficient implementations are present only for the implicit NCHW layout
  • padding
  • element-wise (unary and binary) operations
  • constructors of tensors that inherit memory format, e.g. empty_like.

API and Behavior Changes

Define concept of memory format in PyTorch:

  • Constants like torch.memory_format.channels_first. They don't have a specified type and can be arbitrary comparable objects (likely starting as an enum, but in the future they might be other objects, to interoperate with the concept of named tensors)
    • Alternative: use torch.channels_first directly
  • Values are channels_first and channels_last (rank-independent names allow for fewer constants)
  • For 1D images / 3D tensors the values mean NCW, NWC, for 2D images / 4D tensors - NCHW, NHWC, for 3D images / 5D tensors - NCDHW, NDHWC
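
Here is a small sketch of how the two rank-independent values could map onto the concrete dimension orders listed above. The string values and the helpers physical_order and layout_name are hypothetical stand-ins for the proposed torch.memory_format constants. They are not an existing API.

```python
# Hypothetical stand-ins for the proposed torch.memory_format constants.
CHANNELS_FIRST = "channels_first"
CHANNELS_LAST = "channels_last"

# Spatial dimension letters for 1D, 2D and 3D images.
SPATIAL = {1: "W", 2: "HW", 3: "DHW"}

def physical_order(ndim, memory_format):
    """Order in which the logical (N, C, spatial...) dims are laid out in memory."""
    if ndim not in (3, 4, 5):
        raise ValueError("memory formats are defined here for 3D, 4D and 5D tensors")
    if memory_format == CHANNELS_FIRST:
        return tuple(range(ndim))
    if memory_format == CHANNELS_LAST:
        return (0, *range(2, ndim), 1)  # move C after the spatial dims
    raise ValueError(f"unknown memory format: {memory_format!r}")

def layout_name(ndim, memory_format):
    """Human-readable physical layout, e.g. 'NHWC' for a 4D channels_last tensor."""
    order = physical_order(ndim, memory_format)  # validates the arguments first
    logical = "NC" + SPATIAL[ndim - 2]
    return "".join(logical[i] for i in order)
```

For example, layout_name(4, CHANNELS_LAST) gives "NHWC" and layout_name(5, CHANNELS_LAST) gives "NDHWC", mirroring the list above. The logical order passed around stays "NC" followed by the spatial dims in every case.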

Add the following methods to Tensor:

  • x.is_contiguous(torch.memory_format.channels_first)
  • x.to(memory_format=torch.memory_format.channels_first)

Note: there's no x.get_memory_format() function for now, only explicit checks. This allows a wider range of possible implementations. We might want to add it though.

The tensor's semantic layout always stays the same: NCHW! x.size() always returns (n,c,h,w).

Operations preserve memory format as follows:

  • convolution, pooling, etc. (see above) return output in the same memory format as the input and internally dispatch to the best implementation
  • unary element-wise operations preserve the memory format and need to run as fast as on a contiguous tensor
  • binary element-wise operations provide some reasonable guarantees on preserving memory format. These can likely be defined more broadly, but the minimum is:
    • NHWC + scalar → NHWC
    • NHWC + column vector → NHWC
  • backward operations for core CNN ops preserve the same memory format as in the forward path (this may need to be enforced explicitly, because incoming gradients for the output can be in a different memory format)

Memory format is a property of a tensor that is preserved through serialization/deserialization (in case the tensor is a parameter).

Strided implementation

Tensors in PyTorch today have a concept of strides that specify how the logical tensor is laid out in memory. Specifically, each tensor has a strides vector of the same length as its sizes. To index the element at logical index (i0, i1, ..., ik), one takes the dot product with the strides and looks up memory at offset + i0*stride0 + i1*stride1 + ... + ik*stridek. Contiguous tensors thus have strides that are reversed cumulative products of sizes. For example, a 4D tensor with sizes (n,c,h,w) has strides (c*h*w, h*w, w, 1).
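
The stride arithmetic above can be sketched in plain Python on size and stride tuples. contiguous_strides and element_offset are illustrative helpers, not PyTorch functions.

```python
def contiguous_strides(sizes):
    """Strides of a contiguous tensor: reversed cumulative products of sizes."""
    strides = []
    step = 1
    for d in reversed(sizes):  # the last dim varies fastest and gets stride 1
        strides.append(step)
        step *= d
    return tuple(reversed(strides))

def element_offset(index, strides, storage_offset=0):
    """Memory offset of a logical index: storage offset + dot(index, strides)."""
    return storage_offset + sum(i * s for i, s in zip(index, strides))

n, c, h, w = 2, 3, 4, 5
assert contiguous_strides((n, c, h, w)) == (c * h * w, h * w, w, 1)  # (60, 20, 5, 1)
print(element_offset((1, 2, 3, 4), contiguous_strides((n, c, h, w))))  # 119
```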

Strides can physically represent memory formats that are pure dimension reorderings while preserving the default logical NCHW order. This gives an effective definition of memory format transformation:

# implementation of x.to(channels_last)
def to_mem_format_nhwc(x):
    return x.permute(0,2,3,1).contiguous().permute(0,3,1,2)

# implementation of x.to(channels_first)
def to_mem_format_nchw(x):
    return x.contiguous()

In NHWC format the strides vector is (c*h*w, 1, c*w, c). Thus the elements in the memory buffer are laid out contiguously in NHWC order.
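
The NHWC stride formula can be checked with a sketch that uses plain tuples instead of real tensors. It derives channels_last strides for any rank by assigning row-major strides in the physical (N, spatial..., C) order. It then confirms that walking the elements in NHWC order touches memory sequentially. channels_last_strides is a hypothetical helper name.

```python
def channels_last_strides(sizes):
    """Strides for logical sizes (N, C, *spatial) stored in (N, *spatial, C) order."""
    perm = (0, *range(2, len(sizes)), 1)  # physical order of the logical dims
    strides = [0] * len(sizes)
    step = 1
    # Assign row-major strides, walking the physical dims from innermost (C) outward.
    for dim in reversed(perm):
        strides[dim] = step
        step *= sizes[dim]
    return tuple(strides)

n, c, h, w = 2, 3, 4, 5
st = channels_last_strides((n, c, h, w))
assert st == (c * h * w, 1, c * w, c)  # (60, 1, 15, 3)

# Visiting elements in (n, h, w, c) order touches offsets 0, 1, 2, ... in sequence.
offsets = [ni * st[0] + ci * st[1] + hi * st[2] + wi * st[3]
           for ni in range(n) for hi in range(h) for wi in range(w) for ci in range(c)]
assert offsets == list(range(n * c * h * w))
```

The same function gives NWC strides (c*w, 1, c) for 3D tensors and NDHWC strides (c*d*h*w, 1, c*h*w, c*w, c) for 5D tensors.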

Strides can also be used to test for a memory format:

def is_nhwc_contiguous(x):
    return x.permute(0,2,3,1).is_contiguous()

# or alternatively
def is_nhwc_contiguous(x):
    n,c,h,w = x.size() # in any case the sizes remain in NCHW order
    return x.stride() == (c*h*w, 1, c*w, c)

def is_nchw_contiguous(x):
    return x.is_contiguous()
    

# operator implementations can just check contiguity and then work directly on the data pointer
def my_sample_op(x):
    # my_nhwc_op is a placeholder for a kernel expecting an (n, h, w, c) buffer;
    # in practice it would be C++ code reading x.data_ptr() directly
    if is_nhwc_contiguous(x):
        # memory is already an (n, h, w, c) array: permute gives a view, no copy
        y = my_nhwc_op(x.permute(0, 2, 3, 1))
        # keep the output in NHWC memory, but in logical NCHW order
        return y.permute(0, 3, 1, 2)
    else:
        # convert x to NHWC layout first (this copies the data)
        x_nhwc = x.permute(0, 2, 3, 1).contiguous()
        y = my_nhwc_op(x_nhwc)
        # the input was not NHWC, so return a regular NCHW-contiguous tensor
        return y.permute(0, 3, 1, 2).contiguous()

Pros of this approach:

  • Utilizes existing PyTorch concept of strides without adding new top-level ideas or API parameters
  • Preserves logical behavior of tensor in canonical NCHW order
  • Works for arbitrary reordering of input dimensions
  • Existing serialization routines already preserve the strides of tensors
  • Many operations can be reused to work on different memory layouts

Cons:

  • Calling .contiguous() is equivalent to switching to NCHW and may happen by accident, either in user code or inside one of the ops
    • Explicit audit of operators is needed to ensure they preserve memory format
  • Doesn't work for blocked / tiled formats - a different approach is needed
    • It's possible to consider adding them as first-class citizens in PyTorch, but it's a much bigger change
    • Alternative is to treat them as opaque handles, e.g. MKLDNN tensors
  • Performance characteristics of underlying implementations are less obvious to the end user

The biggest potential problem is unclear user intent. There's no way to distinguish whether the user really wanted a different memory format or the input tensor just happened to be strided this way. Specifically, it leads to a behavior change for existing operations. Today convolution can only produce NCHW-contiguous tensors, even if the input is arbitrarily strided. In the new world it might recognize the input as NHWC and thus return NHWC too. This doesn't change semantics, but it leads to hard-to-debug performance issues. A possible solution is to tag tensors explicitly with a user-specified memory_format flag and only follow this annotation (in addition to strides).

To solve the above issue, the initial proposal is to introduce a “soft” memory format tag on the tensor that records the last to(memory_format) call made on it. Operators would need to propagate this annotation to their outputs. The annotation is “soft”, so we won't hard-error on mismatching annotations but will rather produce warnings in profiling mode.
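
Here is a minimal sketch of how a “soft” tag could be propagated. It uses a hypothetical TaggedTensor stand-in that carries only the annotation. The propagation rule (follow the first tagged input, warn on disagreement) is an assumption for illustration. Also, the RFC only produces warnings in profiling mode, while this sketch warns unconditionally.

```python
import warnings
from dataclasses import dataclass
from typing import Optional

@dataclass
class TaggedTensor:
    """Hypothetical stand-in for a tensor that only tracks the soft tag."""
    name: str
    memory_format_tag: Optional[str] = None  # last format requested via to(memory_format=...)

def to_memory_format(t: TaggedTensor, fmt: str) -> TaggedTensor:
    # A real implementation would also restride the data; here we only record intent.
    return TaggedTensor(t.name, fmt)

def propagate_tag(op_name: str, *inputs: TaggedTensor) -> TaggedTensor:
    tags = [t.memory_format_tag for t in inputs if t.memory_format_tag is not None]
    if len(set(tags)) > 1:
        # Soft annotation: a mismatch only produces a warning, never an error.
        warnings.warn(f"{op_name}: inputs carry different memory format tags {sorted(set(tags))}")
    # Follow the first tagged input; untagged inputs don't override user intent.
    return TaggedTensor(f"{op_name}_out", tags[0] if tags else None)
```

With this rule, adding an untagged tensor to a channels_last-tagged one keeps the channels_last tag on the output. Adding a channels_first-tagged tensor to it also keeps channels_last, but emits a warning.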

Operator implementations

Signatures of existing operators don't change. Operators can do hard-coded dispatch inside the operator to route to a faster implementation. If such an implementation is not available, round-tripping through a different memory format is possible. The alternative would be raising an error.

def maxpool(x: Tensor):
    if x.is_contiguous(torch.memory_format.channels_last):
        return max_pool_impl_nhwc(x)
    return max_pool_impl_default(x.contiguous())

It's preferred to use a single symbol like 'conv' to refer to the operators in JIT IR instead of creating separate operators like 'conv_nhwc'. The reasons are simplicity and keeping the IR at the level of semantic representation.

Element-wise operations

We have to ensure that core operations like element-wise preserve memory format and are efficient.

Unary operations can be handled generically by verifying whether a block of memory is “dense”, i.e. whether the elements span an area without gaps and each memory location is used exactly once. This can be verified with a simple algorithm:

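def is_dense_format(x):
    # walk the dims from smallest to largest stride: in a dense tensor each stride
    # equals the number of elements spanned by the faster-varying dims
    p = 1
    for s, d in sorted(zip(x.stride(), x.size())):
        if d == 1:
            continue  # a size-1 dim can have any stride without leaving gaps
        if s != p:
            return False
        p *= d
    return True

def my_unary(x):
    if is_dense_format(x):
        # elements fill one gap-free block, so a flat 1-D kernel over numel() elements works
        # (contig_memory_impl / default_strided_impl are placeholder kernels)
        return contig_memory_impl(x.data_ptr(), x.numel())
    return default_strided_impl(x)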
    
# is_dense_format can be used in implementations of e.g. empty_like too
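
Here is one way the dense check could feed into empty_like. It is sketched on plain size/stride tuples so that it runs without torch. is_dense and empty_like_strides are hypothetical names. The fallback for non-dense inputs (keep the input's dimension order but drop the gaps) is an assumed rule, not something this proposal specifies.

```python
def is_dense(sizes, strides):
    """True if the elements fill one gap-free block with no memory location reused."""
    p = 1
    for s, d in sorted(zip(strides, sizes)):
        if d == 1:
            continue  # size-1 dims can carry any stride
        if s != p:
            return False
        p *= d
    return True

def empty_like_strides(sizes, strides):
    """Strides for a fresh output that keeps the input's memory format."""
    if is_dense(sizes, strides):
        return tuple(strides)  # already gap-free: reuse the input's strides as-is
    # Assumed fallback: keep the input's dimension order (largest stride outermost)
    # but pack the dims without gaps.
    order = sorted(range(len(sizes)), key=lambda i: strides[i], reverse=True)
    out = [0] * len(sizes)
    step = 1
    for dim in reversed(order):  # innermost (smallest stride) first
        out[dim] = step
        step *= sizes[dim]
    return tuple(out)
```

For example, slicing every other column out of an NHWC tensor with sizes (2, 3, 4, 10) leaves sizes (2, 3, 4, 5) with strides (120, 1, 30, 6), which are not dense. The sketch then allocates the output with the channels_last strides (60, 1, 15, 3) instead of silently falling back to NCHW.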

Performance tooling

For debugging performance we should add support to the profiler for:

  • seeing where in the program actual memory reorderings occur - i.e. track calls to .contiguous()
  • tracking which implementation is invoked
  • issuing warnings on memory format changes in e.g. binary ops (where the “soft” annotation is useful)

This functionality can be built into an on-demand profiling tool.

Autograd handling

It's logical to expect that the backward pass should run with the same memory format as the forward pass. This won't always happen automatically, as incoming gradients might be arbitrarily strided. Thus the forward pass has to explicitly recognize the memory format, store it in the autograd closure and apply it to the gradient tensor before running the backward function.

Possible implementation:

def conv_backward(input, weight, grad_output, grad_weight, grad_input):
  if input.is_contiguous(torch.memory_format.channels_last):
    grad_output = grad_output.to(memory_format=torch.memory_format.channels_last)
    return conv_backward_nhwc(...)
  else:
    grad_output = grad_output.contiguous()
    return conv_backward_nchw(...)

Representation in JIT

Current proposal is to have:

  • No first-class handling of memory format in type annotations just yet. Instead, passes that manipulate memory format can maintain a lookaside map in whatever shape they need
  • An inference pass (similar to shape_inference) that produces per-Value format annotations
  • Memory format transformation passes (manual or automatic) that find where to(memory_format) calls need to be inserted for optimal performance

For enforcement purposes, we can also utilize statements like assert x.is_contiguous(channels_last).

Note: There's a question of where to store the information that a particular device has a preferred memory format (for example, qconv on x86 routes to fbgemm, which implements NHWC only). One option is to put it at the op registration level. However, a memory format annotation feels more like side information. We can start by maintaining a global map somewhere in a JIT pass that records preferred memory formats and associated heuristics. If it gets untidy, we can switch to a registration-based mechanism.
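
Here is a toy version of the inference and transformation passes, sketched over a linear list of nodes rather than real JIT IR. PREFERRED_FORMAT plays the role of the global map of preferred formats. PREFERRED_FORMAT, the node tuple representation and insert_format_conversions are all hypothetical. The sketch also assumes that every op preserves the memory format of its first input.

```python
from typing import Dict, List, Tuple

# (output value, op name, input values): a toy, linear stand-in for JIT IR
Node = Tuple[str, str, List[str]]

# Hypothetical global map of preferred formats, e.g. qconv on x86 -> NHWC-only fbgemm
PREFERRED_FORMAT: Dict[str, str] = {"qconv": "channels_last"}

def insert_format_conversions(graph: List[Node], input_formats: Dict[str, str]):
    formats = dict(input_formats)  # lookaside map: per-Value format annotation
    converted: Dict[Tuple[str, str], str] = {}  # reuse a conversion if requested twice
    new_graph: List[Node] = []
    for out, op, inputs in graph:
        wanted = PREFERRED_FORMAT.get(op)
        new_inputs = []
        for v in inputs:
            if wanted is not None and formats[v] != wanted:
                key = (v, wanted)
                if key not in converted:
                    conv = f"{v}_{wanted}"
                    new_graph.append((conv, f"to[{wanted}]", [v]))
                    formats[conv] = wanted
                    converted[key] = conv
                v = converted[key]
            new_inputs.append(v)
        new_graph.append((out, op, new_inputs))
        # Assumption: every op preserves the memory format of its first input.
        formats[out] = formats[new_inputs[0]]
    return new_graph, formats

g = [("y", "qconv", ["x"]), ("z", "relu", ["y"]), ("w", "qconv", ["z"])]
new_g, fmts = insert_format_conversions(g, {"x": "channels_first"})
```

On a qconv, relu, qconv chain this inserts a single conversion before the first qconv and keeps the rest of the chain in channels_last. This mirrors the "convert once, run several ops" argument from the problem statement. A real pass would also decide where to convert back.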

Beyond: blocked layouts

If we decide to add more complex packings of tensors, using a first-class PyTorch tensor for them might not be feasible because of the high implementation cost and complexity. Two alternatives are possible:

  • Opaque representations like custom C type bindings. This is the option to choose for packing in inference, where there is more diversity in terms of perf optimizations
  • First-class tensor type like MKLDNNTensor with some (but not all) of the operations bound on this new type

Yet another alternative is to implement native support for blocking/tiling in core PyTorch Tensor class.

Named tensor relation

The existing proposal for NamedTensor is structured as a type-checking mechanism on tensors. At the moment it doesn't assign any semantic meaning to dimension names. Thus the only way to infer the meaning of an activation tensor is to continue using the predetermined NCHW format. This makes NamedTensor and the current proposal orthogonal.

If we're willing to hard-specify the meanings of some names (like “channels”, “width”), operators can utilize this information to route to faster implementations. It'd be a semantic change, though, as input tensors could then logically be in NHWC (not NCHW as today) order.

Prior art

TensorFlow supports both NHWC and NCHW at the operator level, via the data_format parameter; acceptable values are (“NHWC”, “NCHW”) for 4-d inputs, (“NDHWC”, “NCDHW”) for 5-d inputs, or channels_first / channels_last independent of input dimensionality. It is up to the user to handle setting the parameter correctly, i.e. it is not tracked automatically by the tensor.

Caffe2 calls this parameter order rather than data_format, but it's still applied explicitly at the individual operator level.


Appendix: Other options considered

Litmus question: what does tensor_in_nhwc_layout.size(1) return? Is it the number of channels (because the default is NCHW in PyTorch) or the height (because that's what is at position 1 in the NHWC layout)?

Based on this answer several options are possible:

  • Option A - Strides (presented above). Tensor layout is a completely internal representation. Implementation-wise, it's most conveniently done with strides.
    • .size(1) returns “channels”, but internal memory is laid out differently
    • pro: doesn't change the model code, so my model can still do dimension arithmetic directly. In fact, none of the public API changes
    • cons: in a strides implementation, many operators call .contiguous() and can accidentally revert the layout back
    • cons: From a user perspective, understanding the guarantees on an op's output format is paramount. This IMO eliminates strides-only approaches, because it becomes very difficult to understand which format your op's output will be returned in, and there's no API to say “ignore my strides, actually just return the NCHW-contiguous thing.” This is in addition to the limitations above.
  • Option B - Explicit NHWC tensor. User explicitly manipulates tensor that has different dimension order but tensor itself doesn't know anything about it. We'd need some annotation on operator level to figure out what user expects.
    • .size(1) returns “height”
    • pro: no magic and very predictable
    • cons: changing model from one layout to another becomes a complex operation that needs to track all accesses to .size() and .reshape() (or you need to make it explicit in the API?)
  • Option B' - Explicit NHWC tensor with a layout flag. Same as above, but we allow attaching an annotation to the tensor to mark its semantic layout, which ops consume in their implementation. There's no need for operator-level annotation then, because an operator can dispatch based on the layout flag of its inputs.
  • Option C - Named Tensor. (https://docs.google.com/document/d/1ynu3wA2hcjwOtEng04N904gJjEbZWcINXO_ardX6hxc/edit#heading=h.2gbe5xpga3w9)
    • .size(1) returns “height” but we ask people to NOT use this API and instead use .size('channel')
    • pro: very explicit and what user wants
    • con: doesn't solve the transition problem, we'd need to force all code written with layout awareness to use named tensors. If not - the same problems as above apply
  • Option D - Layout is opaque tensor type. Treat NHWC as we treat MKLDNN or SparseTensor - separate tensor type with different DispatchID. It's like Option A but with different tradeoffs on default behavior - non-implemented ops would fail instead of reverting to NCHW.
    • .size(1) still returns “channels”
    • pro: no magic and explicit, separate dispatch allows ops to decide what they want
    • pro/cons: all necessary operators need to be implemented on different layout, if some op is missing, user would get an explicit error that it's not supported
    • cons: we probably would need to ban many operations on it, e.g. views because expected results are hard to predict

Labels: module: internals, module: mkldnn, triaged
