Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit ed83071

Browse files
authored
Julia: split symbolic-node.jl into several snippets (#14024)
- `symbolic-node/type.jl` - `symbolic-node/show.jl` - `symbolic-node/arithmetic.jl` - `symbolic-node/io.jl` - `symbolic-node/array.jl` - `symbolic-node/op.jl` - `symbolic-node/autodiff.jl` See also: #14001
1 parent ce9e3cf commit ed83071

File tree

8 files changed

+1121
-980
lines changed

8 files changed

+1121
-980
lines changed

julia/src/symbolic-node.jl

Lines changed: 7 additions & 980 deletions
Large diffs are not rendered by default.
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import Base: +
19+
20+
"""
21+
+(args...)
22+
.+(args...)
23+
24+
Elementwise summation of `SymbolicNode`.
25+
"""
26+
function +(x::SymbolicNode, ys::SymbolicNodeOrReal...)
27+
ret = x
28+
for y ys
29+
if y isa SymbolicNode
30+
ret = _plus(ret, y)
31+
else
32+
ret = _plus_scalar(ret, scalar=MX_float(y))
33+
end
34+
end
35+
ret
36+
end
37+
38+
+(s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x + s, ys...)
39+
40+
broadcasted(::typeof(+), x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x, ys...)
41+
broadcasted(::typeof(+), s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) = +(x + s, ys...)
42+
43+
import Base: -
44+
45+
"""
46+
-(x, y)
47+
.-(x, y)
48+
49+
Elementwise substraction of `SymbolicNode`.
50+
Operating with `Real` is available.
51+
"""
52+
x::SymbolicNode - y::SymbolicNode = _minus(x, y)
53+
x::SymbolicNode - s::Real = _minus_scalar(x, scalar=MX_float(s))
54+
s::Real - x::SymbolicNode = _rminus_scalar(x, scalar=MX_float(s))
55+
56+
-(x::SymbolicNode) = 0 - x
57+
58+
broadcasted(::typeof(-), x::SymbolicNode, y::SymbolicNodeOrReal) = x - y
59+
broadcasted(::typeof(-), s::Real, x::SymbolicNode) = s - x
60+
61+
import Base: *
62+
63+
"""
64+
.*(x, y)
65+
66+
Elementwise multiplication of `SymbolicNode`.
67+
"""
68+
x::SymbolicNode * s::Real = _mul_scalar(x, scalar=MX_float(s))
69+
s::Real * x::SymbolicNode = _mul_scalar(x, scalar=MX_float(s))
70+
71+
function broadcasted(::typeof(*), x::SymbolicNode, ys::SymbolicNodeOrReal...)
72+
ret = x
73+
for y in ys
74+
if y isa SymbolicNode
75+
ret = _mul(ret, y)
76+
else
77+
ret = _mul_scalar(ret, scalar=MX_float(y))
78+
end
79+
end
80+
ret
81+
end
82+
83+
broadcasted(::typeof(*), s::Real, x::SymbolicNode, ys::SymbolicNodeOrReal...) =
84+
broadcasted(*, x * s, ys...)
85+
86+
import Base: /
87+
88+
"""
89+
./(x, y)
90+
91+
* Elementwise dividing a `SymbolicNode` by a scalar or another `SymbolicNode`
92+
of the same shape.
93+
94+
* Elementwise divide a scalar by an `SymbolicNode`.
95+
96+
* Matrix division (solving linear systems) is not implemented yet.
97+
"""
98+
x::SymbolicNode / s::Real = _DivScalar(x, scalar=MX_float(s))
99+
100+
broadcasted(::typeof(/), x::SymbolicNode, y::SymbolicNode) = _div(x, y)
101+
broadcasted(::typeof(/), x::SymbolicNode, s::Real) = _div_scalar(x, scalar=MX_float(s))
102+
broadcasted(::typeof(/), s::Real, x::SymbolicNode) = _rdiv_scalar(x, scalar=MX_float(s))
103+
104+
105+
import Base: ^
106+
107+
"""
108+
.^(x, y)
109+
110+
Elementwise power of `SymbolicNode` and `NDArray`.
111+
Operating with `Real` is available.
112+
"""
113+
^
114+
115+
broadcasted(::typeof(^), x::SymbolicNode, y::SymbolicNode) = _power(x, y)
116+
broadcasted(::typeof(^), x::SymbolicNode, s::Real) = _power_scalar(x, scalar = s)
117+
broadcasted(::typeof(^), s::Real, x::SymbolicNode) = _rpower_scalar(x, scalar = s)
118+
broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::SymbolicNode, ::Val{s}) where {s} =
119+
_power_scalar(x, scalar = s)
120+
121+
broadcasted(::typeof(^), ::Irrational{:ℯ}, x::SymbolicNode) = exp(x)
122+
broadcasted(::typeof(^), x::SymbolicNode, s::Irrational) =
123+
_power_scalar(x, scalar=MX_float(s))
124+
broadcasted(::typeof(^), s::Irrational, x::SymbolicNode) =
125+
_rpower_scalar(x, scalar=MX_float(s))
126+
127+

julia/src/symbolic-node/array.jl

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# Base.Array related interface
19+
20+
import Base: reshape
21+
22+
"""
23+
reshape(sym::SymbolicNode, dim; reverse=false, name)
24+
reshape(sym::SymbolicNode, dim...; reverse=false, name)
25+
26+
Reshape SymbolicNode operator
27+
28+
Some dimensions of the shape can take special values from the set
29+
{0, -1, -2, -3, -4}.
30+
The significance of each is explained below:
31+
32+
- `0` copy this dimension from the input to the output shape.
33+
34+
Example:
35+
36+
- input shape = (2,3,4), shape = (4,0,2), output shape = (4,3,2)
37+
- input shape = (2,3,4), shape = (2,0,0), output shape = (2,3,4)
38+
39+
- `-1` infers the dimension of the output shape by using the remainder of the
40+
input dimensions keeping the size of the new array same as that of the input
41+
array. At most one dimension of shape can be -1.
42+
43+
Example:
44+
45+
- input shape = (2,3,4), shape = (6,1,-1), output shape = (6,1,4)
46+
- input shape = (2,3,4), shape = (3,-1,8), output shape = (3,1,8)
47+
- input shape = (2,3,4), shape=(-1,), output shape = (24,)
48+
49+
- `-2` copy all/remainder of the input dimensions to the output shape.
50+
51+
Example:
52+
53+
- input shape = (2,3,4), shape = (-2,), output shape = (2,3,4)
54+
- input shape = (2,3,4), shape = (2,-2), output shape = (2,3,4)
55+
- input shape = (2,3,4), shape = (-2,1,1), output shape = (2,3,4,1,1)
56+
57+
- `-3` use the product of two consecutive dimensions of the input shape as the
58+
output dimension.
59+
60+
Example:
61+
62+
- input shape = (2,3,4), shape = (-3,4), output shape = (6,4)
63+
- input shape = (2,3,4,5), shape = (-3,-3), output shape = (6,20)
64+
- input shape = (2,3,4), shape = (0,-3), output shape = (2,12)
65+
- input shape = (2,3,4), shape = (-3,-2), output shape = (6,4)
66+
67+
- `-4` split one dimension of the input into two dimensions passed subsequent
68+
to -4 in shape (can contain -1).
69+
70+
Example:
71+
72+
- input shape = (2,3,4), shape = (-4,1,2,-2), output shape = (1,2,3,4)
73+
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
74+
75+
If the argument `reverse` is set to `1`, then the special values are inferred
76+
from right to left.
77+
78+
Example:
79+
80+
- with `reverse=false`, for input shape = (10,5,4), shape = (-1,0),
81+
output shape would be (40,5)
82+
- with `reverse=true`, output shape will be (50,4).
83+
"""
84+
reshape(sym::SymbolicNode, dim::NTuple{N, Integer}; kwargs...) where {N} =
85+
_reshape(sym, dim; kwargs...)
86+
reshape(sym::SymbolicNode, dim::Integer...; kwargs...) =
87+
_reshape(sym, dim; kwargs...)
88+
89+
@inline function _reshape(sym::SymbolicNode, dim::NTuple{N,Integer};
90+
reverse::Bool=false, name::String="") where N
91+
op = _get_cached_libmx_op_handle("reshape")
92+
node = _create_atomic_symbol(op.value, ["shape", "reverse"],
93+
[dump_mx_param(dim), dump_mx_param(!reverse)])
94+
name = get!(DEFAULT_NAME_MANAGER, name, "reshape")
95+
_compose!(node, name=name, data=sym)
96+
end
97+
98+
################################################################################
99+
# Base.getindex
100+
################################################################################
101+
102+
"""
103+
getindex(self :: SymbolicNode, idx :: Union{Int, Base.Symbol, AbstractString})
104+
105+
Get a node representing the specified output of this node. The index could be
106+
a symbol or string indicating the name of the output, or a 1-based integer
107+
indicating the index, as in the list of [`list_outputs`](@ref).
108+
"""
109+
function Base.getindex(self :: SymbolicNode, idx :: Union{Base.Symbol, AbstractString})
110+
idx = Symbol(idx)
111+
i_idx = findall(idx .== list_outputs(self))
112+
@assert(length(i_idx) > 0, "Cannot find output with name '$idx'")
113+
@assert(length(i_idx) < 2, "Found duplicated output with name '$idx'")
114+
Base.getindex(self, i_idx[1])
115+
end
116+
function Base.getindex(self :: SymbolicNode, idx :: Int)
117+
ref_hdr = Ref{MX_handle}(0)
118+
# note Julia is 1-based, while MXNet is 0-based
119+
@mxcall(:MXSymbolGetOutput, (MX_handle, MX_uint, Ref{MX_handle}), self, idx-1, ref_hdr)
120+
return SymbolicNode(MX_SymbolHandle(ref_hdr[]))
121+
end
122+

0 commit comments

Comments
 (0)