|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +# Base.Array related interface |
| 19 | + |
| 20 | +import Base: reshape |
| 21 | + |
| 22 | +""" |
| 23 | + reshape(sym::SymbolicNode, dim; reverse=false, name) |
| 24 | + reshape(sym::SymbolicNode, dim...; reverse=false, name) |
| 25 | +
|
| 26 | +Reshape SymbolicNode operator |
| 27 | +
|
| 28 | +Some dimensions of the shape can take special values from the set |
| 29 | +{0, -1, -2, -3, -4}. |
| 30 | +The significance of each is explained below: |
| 31 | +
|
| 32 | +- `0` copy this dimension from the input to the output shape. |
| 33 | +
|
| 34 | + Example: |
| 35 | +
|
| 36 | + - input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2) |
| 37 | + - input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4) |
| 38 | +
|
| 39 | +- `-1` infers the dimension of the output shape by using the remainder of the |
| 40 | + input dimensions keeping the size of the new array same as that of the input |
| 41 | + array. At most one dimension of shape can be -1. |
| 42 | +
|
| 43 | + Example: |
| 44 | +
|
| 45 | + - input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4) |
| 46 | + - input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8) |
| 47 | + - input shape = (2,3,4), shape=(-1,), output shape = (24,) |
| 48 | +
|
| 49 | +- `-2` copy all/remainder of the input dimensions to the output shape. |
| 50 | +
|
| 51 | + Example: |
| 52 | +
|
| 53 | + - input shape = (2,3,4), shape = (-2,), output shape = (2,3,4) |
| 54 | + - input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4) |
| 55 | + - input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1) |
| 56 | +
|
| 57 | +- `-3` use the product of two consecutive dimensions of the input shape as the |
| 58 | + output dimension. |
| 59 | +
|
| 60 | + Example: |
| 61 | +
|
| 62 | + - input shape = (2,3,4), shape = (-3,4), output shape = (6,4) |
| 63 | + - input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20) |
| 64 | + - input shape = (2,3,4), shape = (0,-3), output shape = (2,12) |
| 65 | + - input shape = (2,3,4), shape = (-3,-2), output shape = (6,4) |
| 66 | +
|
| 67 | +- `-4` split one dimension of the input into two dimensions passed subsequent |
| 68 | + to -4 in shape (can contain -1). |
| 69 | +
|
| 70 | + Example: |
| 71 | +
|
| 72 | + - input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4) |
| 73 | + - input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4) |
| 74 | +
|
| 75 | +If the argument `reverse` is set to `1`, then the special values are inferred |
| 76 | +from right to left. |
| 77 | +
|
| 78 | + Example: |
| 79 | +
|
| 80 | + - with `reverse=false`, for input shape = (10,5,4), shape = (-1,0), |
| 81 | + output shape would be (40,5) |
| 82 | + - with `reverse=true`, output shape will be (50,4). |
| 83 | +""" |
| 84 | +reshape(sym::SymbolicNode, dim::NTuple{N, Integer}; kwargs...) where {N} = |
| 85 | + _reshape(sym, dim; kwargs...) |
| 86 | +reshape(sym::SymbolicNode, dim::Integer...; kwargs...) = |
| 87 | + _reshape(sym, dim; kwargs...) |
| 88 | + |
| 89 | +@inline function _reshape(sym::SymbolicNode, dim::NTuple{N,Integer}; |
| 90 | + reverse::Bool=false, name::String="") where N |
| 91 | + op = _get_cached_libmx_op_handle("reshape") |
| 92 | + node = _create_atomic_symbol(op.value, ["shape", "reverse"], |
| 93 | + [dump_mx_param(dim), dump_mx_param(!reverse)]) |
| 94 | + name = get!(DEFAULT_NAME_MANAGER, name, "reshape") |
| 95 | + _compose!(node, name=name, data=sym) |
| 96 | +end |
| 97 | + |
| 98 | +################################################################################ |
| 99 | +# Base.getindex |
| 100 | +################################################################################ |
| 101 | + |
| 102 | +""" |
| 103 | + getindex(self :: SymbolicNode, idx :: Union{Int, Base.Symbol, AbstractString}) |
| 104 | +
|
| 105 | +Get a node representing the specified output of this node. The index could be |
| 106 | +a symbol or string indicating the name of the output, or a 1-based integer |
| 107 | +indicating the index, as in the list of [`list_outputs`](@ref). |
| 108 | +""" |
| 109 | +function Base.getindex(self :: SymbolicNode, idx :: Union{Base.Symbol, AbstractString}) |
| 110 | + idx = Symbol(idx) |
| 111 | + i_idx = findall(idx .== list_outputs(self)) |
| 112 | + @assert(length(i_idx) > 0, "Cannot find output with name '$idx'") |
| 113 | + @assert(length(i_idx) < 2, "Found duplicated output with name '$idx'") |
| 114 | + Base.getindex(self, i_idx[1]) |
| 115 | +end |
| 116 | +function Base.getindex(self :: SymbolicNode, idx :: Int) |
| 117 | + ref_hdr = Ref{MX_handle}(0) |
| 118 | + # note Julia is 1-based, while MXNet is 0-based |
| 119 | + @mxcall(:MXSymbolGetOutput, (MX_handle, MX_uint, Ref{MX_handle}), self, idx-1, ref_hdr) |
| 120 | + return SymbolicNode(MX_SymbolHandle(ref_hdr[])) |
| 121 | +end |
| 122 | + |
0 commit comments