Skip to content

Commit f15d531

Browse files
authored
refactor: reduce omega's dependency on fvar IDs (#9723)
This PR replaces some `HashSet Expr`-typed collections of facts in `omega`'s implementation with plain lists. This change makes some `omega` calls faster, some slower, but the advantage is that `omega`'s performance is more independent the state of the name generator that produces fvar IDs. I've created this PR for discussion and am happy to hear opinions on whether this should be merged or not. A good reason *not* to merge is that it causes regressions in some places and `grind` is expected to supersede `omega` either way. A good reason to merge is that `omega` is used all over the place and its flaky performance increases the noise in future benchmarks.
1 parent e0fcaf5 commit f15d531

File tree

2 files changed

+30
-37
lines changed

2 files changed

+30
-37
lines changed

src/Lean/Elab/Tactic/Omega/Frontend.lean

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def mkCoordinateEvalAtomsEq (e : Expr) (n : Nat) : OmegaM Expr := do
8686
mkEqTrans eq (← mkEqSymm (mkApp2 (.const ``LinearCombo.coordinate_eval []) n atoms))
8787

8888
/-- Construct the linear combination (and its associated proof and new facts) for an atom. -/
89-
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
89+
def mkAtomLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
9090
let (n, facts) ← lookup e
9191
return ⟨LinearCombo.coordinate n, mkCoordinateEvalAtomsEq e n, facts.getD ∅⟩
9292

@@ -100,7 +100,7 @@ Gives a small (10%) speedup in testing.
100100
I tried using a pointer based cache,
101101
but there was never enough subexpression sharing to make it effective.
102102
-/
103-
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
103+
partial def asLinearCombo (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
104104
let cache ← get
105105
match cache.get? e with
106106
| some (lc, prf) =>
@@ -126,7 +126,7 @@ We also transform the expression as we descend into it:
126126
* pushing coercions: `↑(x + y)`, `↑(x * y)`, `↑(x / k)`, `↑(x % k)`, `↑k`
127127
* unfolding `emod`: `x % k` → `x - x / k`
128128
-/
129-
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
129+
partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
130130
trace[omega] "processing {e}"
131131
match groundInt? e with
132132
| some i =>
@@ -148,7 +148,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
148148
mkEqTrans
149149
(← mkAppM ``Int.add_congr #[← prf₁, ← prf₂])
150150
(← mkEqSymm add_eval)
151-
pure (l₁ + l₂, prf, facts₁.union facts₂)
151+
pure (l₁ + l₂, prf, facts₁ ++ facts₂)
152152
| (``HSub.hSub, #[_, _, _, _, e₁, e₂]) => do
153153
let (l₁, prf₁, facts₁) ← asLinearCombo e₁
154154
let (l₂, prf₂, facts₂) ← asLinearCombo e₂
@@ -158,7 +158,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
158158
mkEqTrans
159159
(← mkAppM ``Int.sub_congr #[← prf₁, ← prf₂])
160160
(← mkEqSymm sub_eval)
161-
pure (l₁ - l₂, prf, facts₁.union facts₂)
161+
pure (l₁ - l₂, prf, facts₁ ++ facts₂)
162162
| (``Neg.neg, #[_, _, e']) => do
163163
let (l, prf, facts) ← asLinearCombo e'
164164
let prf' : OmegaM Expr := do
@@ -184,7 +184,7 @@ partial def asLinearComboImpl (e : Expr) : OmegaM (LinearCombo × OmegaM Expr ×
184184
mkEqTrans
185185
(← mkAppM ``Int.mul_congr #[← xprf, ← yprf])
186186
(← mkEqSymm mul_eval)
187-
pure (some (LinearCombo.mul xl yl, prf, xfacts.union yfacts), true)
187+
pure (some (LinearCombo.mul xl yl, prf, xfacts ++ yfacts), true)
188188
else
189189
pure (none, false)
190190
match r? with
@@ -242,15 +242,15 @@ where
242242
Apply a rewrite rule to an expression, and interpret the result as a `LinearCombo`.
243243
(We're not rewriting any subexpressions here, just the top level, for efficiency.)
244244
-/
245-
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
245+
rewrite (lhs rw : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
246246
trace[omega] "rewriting {lhs} via {rw} : {← inferType rw}"
247247
match (← inferType rw).eq? with
248248
| some (_, _lhs', rhs) =>
249249
let (lc, prf, facts) ← asLinearCombo rhs
250250
let prf' : OmegaM Expr := do mkEqTrans rw (← prf)
251251
pure (lc, prf', facts)
252252
| none => panic! "Invalid rewrite rule in 'asLinearCombo'"
253-
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
253+
handleNatCast (e i n : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
254254
match n with
255255
| .fvar h =>
256256
if let some v ← h.getValue? then
@@ -297,7 +297,7 @@ where
297297
| (``Fin.val, #[n, x]) =>
298298
handleFinVal e i n x
299299
| _ => mkAtomLinearCombo e
300-
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × Std.HashSet Expr) := do
300+
handleFinVal (e i n x : Expr) : OmegaM (LinearCombo × OmegaM Expr × List Expr) := do
301301
match x with
302302
| .fvar h =>
303303
if let some v ← h.getValue? then
@@ -343,12 +343,11 @@ We solve equalities as they are discovered, as this often results in an earlier
343343
-/
344344
def addIntEquality (p : MetaProblem) (h x : Expr) : OmegaM MetaProblem := do
345345
let (lc, prf, facts) ← asLinearCombo x
346-
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
347-
if p.processedFacts.contains e then s else s.insert e
346+
let newFacts : List Expr := facts.filter (p.processedFacts.contains · = false)
348347
trace[omega] "Adding proof of {lc} = 0"
349348
pure <|
350349
{ p with
351-
facts := newFacts.toList ++ p.facts
350+
facts := newFacts ++ p.facts
352351
problem := ← (p.problem.addEquality lc.const lc.coeffs
353352
(some do mkEqTrans (← mkEqSymm (← prf)) h)) |>.solveEqualities }
354353

@@ -359,12 +358,11 @@ We solve equalities as they are discovered, as this often results in an earlier
359358
-/
360359
def addIntInequality (p : MetaProblem) (h y : Expr) : OmegaM MetaProblem := do
361360
let (lc, prf, facts) ← asLinearCombo y
362-
let newFacts : Std.HashSet Expr := facts.fold (init := ∅) fun s e =>
363-
if p.processedFacts.contains e then s else s.insert e
361+
let newFacts : List Expr := facts.filter (p.processedFacts.contains · = false)
364362
trace[omega] "Adding proof of {lc} ≥ 0"
365363
pure <|
366364
{ p with
367-
facts := newFacts.toList ++ p.facts
365+
facts := newFacts ++ p.facts
368366
problem := ← (p.problem.addInequality lc.const lc.coeffs
369367
(some do mkAppM ``le_of_le_of_eq #[h, (← prf)])) |>.solveEqualities }
370368

src/Lean/Elab/Tactic/Omega/OmegaM.lean

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,11 @@ def mkEqReflWithExpectedType (a b : Expr) : MetaM Expr := do
168168
Analyzes a newly recorded atom,
169169
returning a collection of interesting facts about it that should be added to the context.
170170
-/
171-
def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
171+
def analyzeAtom (e : Expr) : OmegaM (List Expr) := do
172172
match e.getAppFnArgs with
173173
| (``Nat.cast, #[.const ``Int [], _, e']) =>
174174
-- Casts of natural numbers are non-negative.
175-
let mut r := (∅ : Std.HashSet Expr).insert (Expr.app (.const ``Int.ofNat_nonneg []) e')
175+
let mut r := [Expr.app (.const ``Int.ofNat_nonneg []) e']
176176
match (← cfg).splitNatSub, e'.getAppFnArgs with
177177
| true, (``HSub.hSub, #[_, _, _, _, a, b]) =>
178178
-- `((a - b : Nat) : Int)` gives a dichotomy
@@ -194,9 +194,8 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
194194
let ne_zero := mkApp3 (.const ``Ne [1]) (.const ``Int []) k (toExpr (0 : Int))
195195
let pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
196196
(toExpr (0 : Int)) k
197-
pure <| (∅ : Std.HashSet Expr).insert
198-
(mkApp3 (.const ``Int.mul_ediv_self_le []) x k (← mkDecideProof ne_zero)) |>.insert
199-
(mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k (← mkDecideProof pos))
197+
pure [mkApp3 (.const ``Int.mul_ediv_self_le []) x k (← mkDecideProof ne_zero),
198+
mkApp3 (.const ``Int.lt_mul_ediv_self_add []) x k (← mkDecideProof pos)]
200199
| (``HMod.hMod, #[_, _, _, _, x, k]) =>
201200
match k.getAppFnArgs with
202201
| (``HPow.hPow, #[_, _, _, _, b, exp]) => match natCast? b with
@@ -206,10 +205,9 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
206205
let b_pos := mkApp4 (.const ``LT.lt [0]) (.const ``Int []) (.const ``Int.instLTInt [])
207206
(toExpr (0 : Int)) b
208207
let pow_pos := mkApp3 (.const ``Lean.Omega.Int.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
209-
pure <| (∅ : Std.HashSet Expr).insert
210-
(mkApp3 (.const ``Int.emod_nonneg []) x k
211-
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos)) |>.insert
212-
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos)
208+
pure [mkApp3 (.const ``Int.emod_nonneg []) x k
209+
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) pow_pos),
210+
mkApp3 (.const ``Int.emod_lt_of_pos []) x k pow_pos]
213211
| (``Nat.cast, #[.const ``Int [], _, k']) =>
214212
match k'.getAppFnArgs with
215213
| (``HPow.hPow, #[_, _, _, _, b, exp]) => match natCast? b with
@@ -220,28 +218,25 @@ def analyzeAtom (e : Expr) : OmegaM (Std.HashSet Expr) := do
220218
(toExpr (0 : Nat)) b
221219
let pow_pos := mkApp3 (.const ``Nat.pos_pow_of_pos []) b exp (← mkDecideProof b_pos)
222220
let cast_pos := mkApp2 (.const ``Int.ofNat_pos_of_pos []) k' pow_pos
223-
pure <| (∅ : Std.HashSet Expr).insert
224-
(mkApp3 (.const ``Int.emod_nonneg []) x k
225-
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos)) |>.insert
226-
(mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos)
221+
pure [mkApp3 (.const ``Int.emod_nonneg []) x k
222+
(mkApp3 (.const ``Int.ne_of_gt []) k (toExpr (0 : Int)) cast_pos),
223+
mkApp3 (.const ``Int.emod_lt_of_pos []) x k cast_pos]
227224
| _ => match x.getAppFnArgs with
228225
| (``Nat.cast, #[.const ``Int [], _, x']) =>
229226
-- Since we push coercions inside `%`, we need to record here that
230227
-- `(x : Int) % (y : Int)` is non-negative.
231-
pure <| (∅ : Std.HashSet Expr).insert (mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k)
228+
pure [mkApp2 (.const ``Int.emod_ofNat_nonneg []) x' k]
232229
| _ => pure ∅
233230
| _ => pure ∅
234231
| (``Min.min, #[_, _, x, y]) =>
235-
pure <| (∅ : Std.HashSet Expr).insert (mkApp2 (.const ``Int.min_le_left []) x y) |>.insert
236-
(mkApp2 (.const ``Int.min_le_right []) x y)
232+
pure [mkApp2 (.const ``Int.min_le_left []) x y, mkApp2 (.const ``Int.min_le_right []) x y]
237233
| (``Max.max, #[_, _, x, y]) =>
238-
pure <| (∅ : Std.HashSet Expr).insert (mkApp2 (.const ``Int.le_max_left []) x y) |>.insert
239-
(mkApp2 (.const ``Int.le_max_right []) x y)
234+
pure [mkApp2 (.const ``Int.le_max_left []) x y, mkApp2 (.const ``Int.le_max_right []) x y]
240235
| (``ite, #[α, i, dec, t, e]) =>
241236
if α == (.const ``Int []) then
242-
pure <| (∅ : Std.HashSet Expr).insert <| mkApp5 (.const ``ite_disjunction [0]) α i dec t e
237+
pure [mkApp5 (.const ``ite_disjunction [0]) α i dec t e]
243238
else
244-
pure {}
239+
pure []
245240
| _ => pure ∅
246241

247242
/--
@@ -254,7 +249,7 @@ Return its index, and, if it is new, a collection of interesting facts about the
254249
* for each new atom of the form `((a - b : Nat) : Int)`, the fact:
255250
`b ≤ a ∧ ((a - b : Nat) : Int) = a - b ∨ a < b ∧ ((a - b : Nat) : Int) = 0`
256251
-/
257-
def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
252+
def lookup (e : Expr) : OmegaM (Nat × Option (List Expr)) := do
258253
let c ← getThe State
259254
let e ← canon e
260255
match c.atoms[e]? with
@@ -264,7 +259,7 @@ def lookup (e : Expr) : OmegaM (Nat × Option (Std.HashSet Expr)) := do
264259
let facts ← analyzeAtom e
265260
if ← isTracingEnabledFor `omega then
266261
unless facts.isEmpty do
267-
trace[omega] "New facts: {← facts.toList.mapM fun e => inferType e}"
262+
trace[omega] "New facts: {← facts.mapM fun e => inferType e}"
268263
let i ← modifyGetThe State fun c =>
269264
(c.atoms.size, { c with atoms := c.atoms.insert e c.atoms.size })
270265
return (i, some facts)

0 commit comments

Comments
 (0)