inline_ofg_expansion clones upstream graph wastefully #2142

@ricardoV94

Description

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.tensor.rewriting.ofg import inline_ofg_expansion

x = pt.scalar("x")
y = pt.exp(x)

# Inline OpFromGraph whose single input is fed by the upstream y = exp(x)
inner_y = y.type()
ofg = pytensor.OpFromGraph([inner_y], [pt.cos(inner_y)], inline=True)
z = ofg(y)

# y is also an output, so exp(x) must be shared by both outputs
fg = FunctionGraph(outputs=[y, z])
assert len(fg.toposort()) == 2  # exp, OpFromGraph

dfs_rewriter(inline_ofg_expansion).rewrite(fg)
# Expected 2 nodes (exp, cos), but the inlining duplicates the upstream exp(x)
assert len(fg.toposort()) == 2, len(fg.toposort())  # AssertionError: 3
```

We call `clone_replace` internally, but it should just be `graph_replace`, or `of.frozen_graph..unfreeze_graph().replace_all()`. There is no need to do any graph analysis, because we already know the closure: the inner graph depends only on the OFG's inner inputs!
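A toy, pure-Python model may make the difference between the two strategies concrete. Everything in it (`Var`, `Apply`, `toy_clone_replace`, `toy_graph_replace`, `count_nodes`) is hypothetical and only mimics the behaviour reported above; it is not PyTensor's implementation of `clone_replace` or `graph_replace`. The first function substitutes the replacement and then keeps cloning upstream, so it copies `exp(x)`; the second stops at the replaced inner input and splices in the existing outer variable.

```python
# Hypothetical toy graph model (NOT PyTensor): each Var is produced by at most one Apply.
class Var:
    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner  # Apply node computing this Var, or None for a root input


class Apply:
    def __init__(self, op, inputs):
        self.op = op
        self.inputs = list(inputs)
        self.output = Var(f"{op}(...)", owner=self)


def apply(op, *inputs):
    """Create a new Apply node and return its output variable."""
    return Apply(op, inputs).output


def toy_clone_replace(outputs, replace):
    """Substitute, then keep cloning upstream: ancestors of replacements are copied too."""
    memo = {}

    def clone(v):
        v = replace.get(v, v)
        if v not in memo:
            if v.owner is None:
                memo[v] = v  # root inputs are shared, not copied
            else:
                memo[v] = apply(v.owner.op, *(clone(i) for i in v.owner.inputs))
        return memo[v]

    return [clone(o) for o in outputs]


def toy_graph_replace(outputs, replace):
    """Clone only the inner graph; replaced inner inputs map straight to outer vars."""
    memo = dict(replace)  # never walk past a replaced inner input

    def clone(v):
        if v not in memo:
            if v.owner is None:
                memo[v] = v
            else:
                memo[v] = apply(v.owner.op, *(clone(i) for i in v.owner.inputs))
        return memo[v]

    return [clone(o) for o in outputs]


def count_nodes(outputs):
    """Number of distinct Apply nodes reachable from outputs (like len(fg.toposort()))."""
    seen, stack = set(), [o.owner for o in outputs if o.owner is not None]
    while stack:
        node = stack.pop()
        if node not in seen:
            seen.add(node)
            stack.extend(i.owner for i in node.inputs if i.owner is not None)
    return len(seen)


# Mirror the reproducer: outer y = exp(x), inner graph cos(inner_y)
x = Var("x")
y = apply("exp", x)
inner_y = Var("inner_y")
inner_out = apply("cos", inner_y)

z_clone = toy_clone_replace([inner_out], {inner_y: y})[0]
z_graph = toy_graph_replace([inner_out], {inner_y: y})[0]

print(count_nodes([y, z_clone]))  # 3: exp, cloned exp, cos
print(count_nodes([y, z_graph]))  # 2: exp, cos
```

In the toy `graph_replace` variant, the inlined `cos` reads directly from the existing `y`, which is the node count the reproducer expects.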

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests