// Based on algorithm from:
// Bruno Dutertre, and Leonardo de Moura: A Fast Linear-Arithmetic Solver for DPLL(T)
// http://www.csl.sri.com/users/demoura/papers/CAV06/index.html

using Nemerle.Collections
using Nemerle.Logging
using Nemerle.Profiling
using Nemerle.Imperative

set namespace Fx7

// Linear arithmetic theory: a simplex tableau over rationals with per-variable
// lower/upper bounds, each bound carrying the Proof that justified it.
// Conflicts go through Refute, which either refutes in the core or, while the
// model is being "shaken" (is_model_shaking), is captured in model_shaking_proof.
[CreateMemento] \
public class LinearTheory : PlainVarTheory
  type real = Rational
  // the list is sorted by Var
  type LinearTerm = list [real * Var]

  // Orders (Var, coefficient) pairs by the variable's position in `vars`.
  // Tableau rows are sorted with this comparer so Tab can binary-search them.
  class VarComparer : System.Collections.Generic.IComparer [Var * real]
    public Compare (v1 : Var * real, v2 : Var * real) : int
      v1[0].idx - v2[0].idx

  var_cmp : VarComparer = VarComparer ()

  // A simplex variable.
  //   idx  - position in `vars` and `model`; -1 until RegisterVar is called.
  //   ridx - index of the tableau row this variable defines when it is basic,
  //          -1 when it is non-basic.
  // The bounds and their proofs are [Copy] fields; presumably the
  // [Rollbackable] macro saves/restores them on backtracking (via WillWrite).
  [Rollbackable] \
  class Var : BaseVar
    pool : LinearTheory
    internal mutable idx : int = -1
    internal mutable ridx : int = -1
    [Copy] internal mutable lower : real = real.NegativeInfinity
    [Copy] internal mutable lower_proof : Proof
    [Copy] internal mutable upper : real = real.PositiveInfinity
    [Copy] internal mutable upper_proof : Proof

    internal this (p : LinearTheory, t : Term)
      base (t)
      pool = p

    // Variables without a term are slack variables created for predicates.
    public override ToString () : string
      if (term == null)
        $ "U_$idx"
      else
        term.ToString ()

    #pragma warning disable 10003
    // Debug description: "<name>: [lower : model value : upper]".
    public Desc : string
      get
        //$ "$this: [$lower : $(pool.model [idx]) : $upper] ($lower_proof, $upper_proof)"
        $ "$this: [$lower : $(pool.model [idx]) : $upper]"
    #pragma warning restore 10003

    public IsNonBasic : bool
      get
        ridx == -1

    public IsBasic : bool
      get
        ridx != -1

    // Proof that the variable is fixed; only valid when lower == upper.
    public ConstProof : Proof
      get
        assert (lower == upper)
        Proof.Rule2 ("linear-const", lower_proof, upper_proof)

    // Tightens the upper bound to `c` (justified by `p`).
    // A bound that is not stricter than the current one is ignored; a bound
    // below the current lower bound is refuted using both proofs.  When a
    // non-basic variable's model value now exceeds `c`, it is moved onto `c`.
    internal AssertUpper (c : real, p : Proof) : void
      log (LIN, $"Assert ($this <= $c)")
      assert (p != null)
      if (c >= upper) {}
      else if (c < lower)
        assert (lower_proof != null)
        pool.Refute (Proof.Rule2 ("simplex-confl-1", lower_proof, p))
      else
        WillWrite ()
        upper = c
        upper_proof = p
        //log (LIN, $"really set-upper $this to $c")
        when (IsNonBasic && pool.model [idx] > c)
          pool.Update (this, c)

    // Mirror image of AssertUpper for the lower bound.
    internal AssertLower (c : real, p : Proof) : void
      log (LIN, $"Assert ($this >= $c)")
      assert (p != null)
      if (c <= lower) {}
      else if (c > upper)
        assert (upper_proof != null && p != null)
        pool.Refute (Proof.Rule2 ("simplex-confl-1", upper_proof, p))
      else
        WillWrite ()
        lower = c
        lower_proof = p
        //log (LIN, $"really set-lower $this to $c")
        when (IsNonBasic && pool.model [idx] < c)
          pool.Update (this, c)

  #region Rollback handling
  internal mutable current_level : int = 0
  [Copy] mutable rollback_queue : list [Var] = []

  // Opens a backtracking level: snapshots the memento and starts an empty
  // queue of variables to roll back.
  public override PushState () : void
    current_level++
    SaveMemento ()
    rollback_queue = []

  // Closes a level: rolls back every Var queued since the matching PushState,
  // then restores the memento (which includes the previous rollback_queue).
  public override PopState () : void
    current_level--
    foreach (r in rollback_queue)
      r.Rollback ()
    RestoreMemento ()

  QueueRollback (t : Var) : void
    rollback_queue ::= t
  #endregion

  // Creates the theory together with the constant variable `one`, whose
  // bounds are fixed to [1, 1] with trivial proofs.
  internal this (id : int, c : Core)
    base (c, id)
    one = Var (this, core.TermPool.Get ("1", []))
    one.lower = real.One
    one.upper = real.One
    one.lower_proof = Proof.True ()
    one.upper_proof = Proof.True ()
    // NOTE(review): RegisterVar gives `one` a model value of 0 although its
    // bounds are [1, 1]; DoCheck's non-basic fallback may exist because of
    // this - confirm before changing.
    RegisterVar (one)

  // The theory owns "*", "+", "-" and integer literals (see IsNumber).
  public override IsMyFunction (head : string) : bool
    match (head)
      | "*" | "+" | "-" => true
      | _ => IsNumber (head)

  // True for an optional leading '-' followed by decimal digits, e.g. "42"
  // or "-7".  A lone "-" (no digits), "" and null are rejected, since
  // real.Parse cannot handle them.
  public static IsNumber (s : string) : bool
    if (s == null || s == "" || s == "-") false
    else if (!(char.IsDigit (s [0]) || s [0] == '-')) false
    else
      foreach (i in [1 .. s.Length - 1])
        when (!char.IsDigit (s [i]))
          return (false)
      true

  public override IsMyPredicate (head : string) : bool
    head == "<" || head == ">" || head == "<=" || head == ">="

  // tableau[r] is the row of the basic variable b with b.ridx == r:
  // (Var, coefficient) pairs sorted by Var.idx, encoding sum(c*x) = 0, in
  // which b itself has coefficient -1 (i.e. b = sum of the other terms).
  tableau : Vec [Vec [Var * real]] = Vec ()
  // All registered variables, indexed by Var.idx.
  vars : Vec [Var] = Vec ()
  // Current assignment, indexed by Var.idx.
  model : Vec [real] = Vec ()
  one : Var
  mutable is_model_shaking : bool
  mutable model_shaking_proof : Proof

  public TableauSize : int
    get
      tableau.Count

  // Drops every row, variable, model value and cache, then re-registers
  // `one` as the single, non-basic variable.
  public override Clear () : void
    tableau.Clear ()
    vars.Clear ()
    model.Clear ()
    canon_cache.Clear ()
    var_cache.Clear ()
    row_cache.Clear ()
    non_linears.Clear ()
    const_cache.Clear ()
    // Register `one` exactly once so it has a single entry in vars/model.
    RegisterVar (one)
    one.ridx = -1

  // While shaking the model, conflicts are only recorded (the caller inspects
  // model_shaking_proof); otherwise they refute in the core.
  Refute (p : Proof) : void
    if (is_model_shaking)
      model_shaking_proof = p
    else
      core.Refute (p)

  // Coefficient of `y` in the row of the basic variable `x` (zero if absent).
  Tab (x : Var, y : Var) : real
    def l = tableau [x.ridx]
    def idx = l.BinarySearch ((y, real.Zero), var_cmp)
    if (idx < 0) real.Zero
    else l [idx][1]

  // Lazily enumerates the basic variables in `vars` order.
  Basics () : System.Collections.Generic.IEnumerable [Var]
    foreach (v when v.IsBasic in vars)
      yield v

  // eliminate s from row of r
  // Adds Tab(r, s) * row(s) to row(r) with a sorted merge; since s has
  // coefficient -1 in its own row, s's entry cancels and zero entries are
  // dropped.  The merged row replaces tableau[r.ridx].
  Eliminate (s : Var, r : Var) : void
    def m = Tab (r, s)
    when (m != 0)
      def l = tableau [s.ridx]
      def row = tableau [r.ridx]
      def res = Vec (l.Count + row.Count - 1)
      mutable i = 0, j = 0
      while (true)
        if (i < l.Count)
          if (j < row.Count)
            def (vl, cl) = l [i]
            def (vr, cr) = row [j]
            if (vl.idx < vr.idx)
              res.Add ((vl, m * cl))
              i++
            else if (vl.idx > vr.idx)
              res.Add ((vr, cr))
              j++
            else
              def c = cr + m * cl
              when (c != 0)
                res.Add ((vr, c))
              i++
              j++
          else
            // row(r) exhausted: copy the scaled tail of row(s)
            while (i < l.Count)
              def (vl, cl) = l [i]
              res.Add ((vl, m * cl))
              i++
            break
        else if (j < row.Count)
          // row(s) exhausted: copy the tail of row(r)
          while (j < row.Count)
            def (vr, cr) = row [j]
            res.Add ((vr, cr))
            j++
          break
        else
          break
      tableau [r.ridx] = res

  // Exchanges basic `r` with non-basic `s`: rescales r's row so that s gets
  // coefficient -1, hands the row over to s, and eliminates s from every
  // other basic row.  The model is not changed.
  Pivot (r : Var, s : Var) : void
    log (LIN, $ "Pivot ($r, $s)")
    DumpTableau ()
    assert (r.IsBasic)
    assert (s.IsNonBasic)
    def mars = -Tab (r, s)
    assert (Tab (r, r) == -1)
    assert (mars != 0)
    def l = tableau [r.ridx]
    foreach (i in [0 .. l.Count - 1])
      def (v, c) = l [i]
      l [i] = (v, c / mars)
    r.ridx <-> s.ridx
    assert (Tab (s, s) == -1)
    foreach (v when v : object != s in Basics ())
      Eliminate (s, v)
    DumpTableau ()

  // Sets non-basic `x` to `v` and shifts every basic variable by its
  // coefficient on x times the change.
  Update (x : Var, v : real) : void
    def d = v - model [x.idx]
    foreach (y in Basics ())
      model [y.idx] += Tab (y, x) * d
    model [x.idx] = v

  // Moves basic `x` to value `v` by adjusting non-basic `y` (and, through
  // their rows, the other basic variables), then pivots x and y.
  PivotAndUpdate (x : Var, y : Var, v : real) : void
    assert (!v.IsInf)
    assert (Tab (x, y) != 0)
    def theta = (v - model [x.idx]) / Tab (x, y)
    model [x.idx] = v
    model [y.idx] += theta
    log (LIN, $"PivotAndUpdate (theta=$theta tab=$(Tab(x,y)) v=$v)")
    foreach (z when z.idx != x.idx in Basics ())
      model [z.idx] += Tab (z, y) * theta
    Pivot (x, y)

  // Repairs basic `x`, which violates its upper bound (neg = true) or its
  // lower bound (neg = false).  If some variable in the row still has slack
  // in the needed direction, pivots on it and returns true.  Otherwise builds
  // a conflict proof from the bounds of the whole row, refutes, returns false.
  CheckRow (x : Var, neg : bool) : bool
    def row = tableau [x.ridx]
    foreach ((v, r) when v : object != x in row)
      def r = if (neg) -r else r
      when ((r > 0 && model [v.idx] < v.upper) || (r < 0 && model [v.idx] > v.lower))
        assert (v.IsNonBasic)
        PivotAndUpdate (x, v, if (neg) x.upper else x.lower)
        return (true)
    mutable p = if (neg) x.upper_proof else x.lower_proof
    assert (p != null)
    foreach ((v, r) in row)
      def r = if (neg) -r else r
      if (r > 0)
        assert (v.upper_proof != null, $"v is $(v.Desc)")
        p = Proof.Rule2 ("simplex-1-1", p, v.upper_proof)
      else if (r < 0)
        assert (v.lower_proof != null, $"v is $(v.Desc)")
        p = Proof.Rule2 ("simplex-1-2", p, v.lower_proof)
      else assert (false)
    Refute (p)
    false

  // Main simplex loop.  Repeatedly picks the first basic variable (in `vars`
  // order) that violates a bound and repairs it with CheckRow.  Returns true
  // when all bounds hold (SAT), false once CheckRow refutes a row (UNSAT).
  DoCheck () : bool
    while (true)
      def x = min:
        foreach (v in Basics ())
          when (model [v.idx] < v.lower || model [v.idx] > v.upper)
            min (v)
        // this shouldn't really happen, but it seems it does
        foreach (v when v.IsNonBasic in vars)
          when (model [v.idx] < v.lower || model [v.idx] > v.upper)
            foreach (q in Basics ())
              when (Tab (q, v) != 0)
                Pivot (q, v)
                min (v)
            // oops, it doesn't occur anywhere, we'll stick it into
            // the allowed range
            assert (v.lower <= v.upper)
            // The lower bound may be -infinity when only the upper bound is
            // violated; never assign an infinite model value.
            model [v.idx] = if (v.lower.IsInf) v.upper else v.lower
        return (true) // SAT
      when (!CheckRow (x, model [x.idx] > x.upper))
        return (false) // UNSAT
    true // never reached

  // Logs every basic row and every variable's bounds/value (LIN logging only).
  DumpTableau () : void
    whenlogging (LIN)
      foreach (v when v.IsBasic in vars)
        def row = tableau [v.ridx]
        // Report the diagonal coefficient via Tab: `row` is indexed by
        // position in the row, not by Var.idx.
        assert (Tab (v, v) == -1, $ "r/v=$(Tab (v, v))")
        def sb = System.Text.StringBuilder ()
        _ = sb.Append ($ "tableau: $v = ")
        foreach ((x, c) in row)
          _ = sb.Append ($ " + $c*$x")
        log (LIN, sb.ToString ())
      foreach (v in vars)
        log (LIN, $ "model: $(v.Desc)")

  // Copies the current model into a fresh array (same indexing as `model`).
  CloneModel () : array [real]
    def backup = array (model.Count)
    foreach (i in [0 .. model.Count - 1])
      backup [i] = model [i]
    backup

  // Runs the simplex check; on UNSAT restores the model that was current
  // before the check.
  Check () : void
    def backup = CloneModel ()
    when (!DoCheck ())
      foreach (i in [0 .. model.Count - 1])
        model [i] = backup [i]
    DumpTableau ()

  // Model value -> variables whose canonical form is exactly that constant.
  const_cache : Hashtable [real, list [Var]] = Hashtable ()
  // Term -> its simplex variable.
  var_cache : Hashtable [Term, Var] = Hashtable ()

  // Creates the Var for `t` (once).  For outer terms the canonical linear form
  // is computed: constant terms are recorded in const_cache, others get a
  // defining tableau row (a non-zero offset becomes a coefficient on `one`).
  [Profile] \
  public override MakeAlien (t : Term, outer : bool) : void
    assert (current_level != 0)
    log (LIN, $"MakeAlien($t, outer=$outer)")
    when (!var_cache.Contains (t))
      def v = Var (this, t)
      var_cache [t] = v
      when (outer)
        def (lterm, off) = Canonize (t)
        if (lterm is [])
          mutable v'
          if (const_cache.TryGetValue (off, out v'))
            const_cache [off] ::= v
          else
            const_cache [off] = [v]
        else
          def lterm = if (off != 0) (off, one) :: lterm else lterm
          AddRow (v, lterm)
          log (LIN, $"addrow/alien: $v = $lterm")

  // Non-linear products, each abstracted by a fresh "$$nonLinear" variable.
  non_linears : Hashtable [Term, Var] = Hashtable ()
  canon_cache : Hashtable [Term, LinearTerm * real] = Hashtable ()

  // Rewrites `t` as (sum of coeff*var, constant offset), with the list sorted
  // by term id and zero coefficients dropped.  Handles binary +/-, unary
  // "~"/"-", multiplication by a literal, and integer literals; any other
  // product becomes an opaque non-linear variable.  Results are cached.
  Canonize (t : Term) : LinearTerm * real
    mutable res
    if (canon_cache.TryGetValue (t, out res))
      res
    else
      // constant folding? nah...
      def get_const (t)
        if (t.Arity == 0 && IsNumber (t.Name))
          Some (real.Parse (t.Name))
        else if (t.Arity == 1 && t.Name == "~" && IsNumber (t.OnlyChild.Name))
          Some (- real.Parse (t.OnlyChild.Name))
        else
          None ()
      mutable add = real.Zero
      def coeff = Hashtable ()
      def walk (mult : real, t)
        match ((t.Name, t.Children))
          | ("-", [left, right]) with sign = real.MinusOne \
          | ("+", [left, right]) with sign = real.One =>
            walk (mult, left)
            walk (sign * mult, right)
          | ("~", [left]) \
          | ("-", [left]) =>
            walk (-mult, left)
          | ("*", [left, right]) =>
            match (get_const (left))
              | Some (v) => walk (v * mult, right)
              | None =>
                match (get_const (right))
                  | Some (v) => walk (v * mult, left)
                  | None =>
                    unless (non_linears.Contains (t))
                      log (WARN, $"Warning: non-linear subterm $t")
                      non_linears [t] = Var (this, core.TermPool.Get ("$$nonLinear", [t]))
                    def v = non_linears [t]
                    if (coeff.Contains (v))
                      coeff [v] += mult
                    else
                      coeff [v] = mult
          | _ =>
            match (get_const (t))
              | Some (v) => add += v * mult
              | None =>
                def v = GetVar (t)
                assert (v != null, $ "no var for $t")
                if (coeff.Contains (v))
                  coeff [v] += mult
                else
                  coeff [v] = mult
      walk (real.One, t)
      def lst = $[(c, v) | (Key = v, Value = c) in coeff, c != 0]
      def sorted = lst.Sort (fun ((_, a), (_, b)) { a.term.Id - b.term.Id })
      res = (sorted, add)
      canon_cache [t] = res
      res

  // Appends `v` to `vars` with model value 0.
  RegisterVar (v : Var) : void
    v.idx = vars.Count
    vars.Add (v)
    model.Add (real.Zero)

  // Makes `v` basic with defining row v = sum(c*x) over `t`.  Unregistered
  // variables in `t` are registered first, v's model value is computed from
  // the current model, and basic variables occurring in `t` are eliminated
  // from the new row so that it mentions only non-basic variables.
  AddRow (v : Var, t : LinearTerm) : void
    v.ridx = tableau.Count
    RegisterVar (v)
    def row = Vec (t.Length + 1)
    mutable val = real.Zero
    foreach ((c, x) in t)
      when (x.idx == -1)
        RegisterVar (x)
      assert (c != 0)
      row.Add ((x, c))
      val += c * model [x.idx]
    model [v.idx] = val
    row.Add ((v, real.MinusOne))
    row.Sort (var_cmp)
    // XXX we should probably check if we're not adding the same row twice
    tableau.Add (row)
    foreach ((x, _) when x.IsBasic && (v : object != x) in row)
      Eliminate (x, v)

  // Looks up the Var of a term; the term must be its own representative.
  GetVar (t : Term) : Var
    def t' = t.Var [Id]
    assert (t' : object == t, $ "t=$t t'=$(t') t.Root=$(t.Root) t'.Root=$(t'.Root)" )
    var_cache [t]

  // For each variable with a non-trivial range, tries to pin it to an
  // integer value `r` (its current value, or the floor/ceiling of it) by
  // proving v >= r and v <= r in temporary core states.  A proof is found by
  // asserting the opposite strict bound (r - 1 / r + 1) and checking for a
  // conflict.  Variables whose model value changes during a successful check
  // are dropped from further consideration.
  [Profile] \
  ShakeModel () : void
    def to_check = Hashtable ()
    foreach (v in vars)
      to_check [v] = v.lower != v.upper
    def backup = CloneModel ()
    def call_check ()
      match (CallCheck (backup))
        | null =>
          foreach (v when to_check [v] in vars)
            when (model [v.idx] != backup [v.idx])
              to_check [v] = false
          null
        | proof => proof
    // Returns a proof of v >= r (low) or v <= r (!low), or null if none found.
    def check_for (v, r, low)
      if (low && v.lower == r)
        v.lower_proof
      else if (!low && v.upper == r)
        v.upper_proof
      else
        try
          core.PushState ()
          if (low)
            v.AssertUpper (r - real.One, Proof.True ())
          else
            v.AssertLower (r + real.One, Proof.True ())
          if (model_shaking_proof != null)
            def p = model_shaking_proof
            model_shaking_proof = null
            p // shouldn't really happen
          else
            call_check ()
        finally
          core.PopState ()
    def check_both (v, r)
      match (check_for (v, r, true))
        | null => ()
        | p => v.AssertLower (r, p)
      match (check_for (v, r, false))
        | null => ()
        | p => v.AssertUpper (r, p)
    // Whether v == r is consistent with the current constraints.
    def possible (v, r)
      try
        core.PushState ()
        v.AssertUpper (r, Proof.True ())
        v.AssertLower (r, Proof.True ())
        if (model_shaking_proof != null)
          model_shaking_proof = null
          false
        else
          call_check () == null
      finally
        core.PopState ()
    is_model_shaking = true
    try
      foreach (v when to_check [v] in vars)
        def current = model [v.idx]
        if (current.IsInt)
          check_both (v, current)
        else
          def current = current.Floor ()
          if (possible (v, current))
            check_both (v, current)
          else
            def current = current + real.One
            if (possible (v, current))
              check_both (v, current)
            else
              // FIXME: should call Refute here
              // oops
              //log (TEMP, $ "Warning: cannot find integer value for $(v.Desc)")
              {}
    finally
      // Always leave shaking mode, even if a check throws.
      is_model_shaking = false

  // NOTE(review): disabled integer-branching experiment.  It looks unfinished
  // (e.g. `StarsWith` is presumably meant to be `StartsWith`); review before
  // re-enabling.
#if false
  IntBranch () : void
    assert (!core.MultiRefutation)
    Integerize ()
    when (core.Refuted)
      return
    def consts = const_cache.Clone ()
    try
      core.PushState ()
      foreach (v when v.term.Active in vars)
        mutable lst
        if (consts.TryGetValue (model [v.idx], out lst))
          mutable found = false
          mutable candidate = null
          foreach (v' in lst)
            if (v' : object == v)
              found = true
            else if (candidate == null && v'.term.Active)
              candidate = v'
            else {}
          when (! found)
            consts [model [v.idx]] ::= v
            core.TermPool.Merge (v.term, v'.term, Proof.Rule2 ($"linear-fake-$(v.idx)-$(v'.idx)", Proof.True (), Proof.True ()))
            when (core.Refuted)
              break
        else
          consts [model [v.idx]] = [v]
      unless (core.Refuted)
        core.FinalTheoryCheck ()
      if (core.Refuted)
        def problems = Hashtable ()
        def visited = Hashtable ()
        def walk (p : Proof)
          if (visited.Contains (p)) {}
          else
            visited [p] = p
            match (p)
              | Rule2 (name, True, True) =>
                when (name.StarsWith ("linear-fake-"))
                  def parts = name.Split ('-')
                  def v1 = int.Parse (parts [2])
                  def v2 = int.Parse (parts [3])
                  problems [v1] = vars [v1]
                  problems [v2] = vars [v2]
              | Rule2 (_, p1, p2) =>
                walk (p1)
                walk (p2)
              | Rule3 (_, p1, p2, p3) =>
                walk (p1)
                walk (p2)
                walk (p3)
              | Rule4 (_, p1, p2, p3, p4) =>
                walk (p1)
                walk (p2)
                walk (p3)
                walk (p4)
              | EqProof (t1, t2) =>
                walk (t1.ProofOfEqualityWith (t2))
              | True
              | Literal => {}
        walk (core.RefutationProof)
        core.PopState ()
        core.PushState ()
        def vars_list = $[ v | v in problems.Values ]
        Narrow (vars_list)
        mutable min_diff = real.PositiveInfinity
        mutable min_v = null
        foreach (v in vars_list)
          when (v.lower == v.upper)
            continue
          def diff = v.upper - v.lower
          if (diff < real.One)
            // we cannot fit anything else here anyway
            {}
          else if (diff < min_diff)
            min_diff = diff
            min_v = v
          else {}
        // XXX
        when (min_diff < max_range) {}
        // XXX
      else {}
    finally
      core.PopState ()
#endif

  // Called on final check: run simplex and, if still consistent, propagate
  // equalities between terms to the core (shaking the model first when full).
  [Profile] \
  GenerateEqualities (full : bool) : void
    log (LIN, "GenerateEq")
    when (full)
      ShakeModel ()
    // first, we don't want constants to be basic
    // A fixed basic variable is pivoted with a non-fixed column of its row,
    // so rows of other variables can fold it into their constant offset.
    foreach (v in Basics ())
      when (v.lower == v.upper)
        def row = tableau [v.ridx]
        foreach ((x, _) in row)
          when (x.idx != v.idx && x.lower != x.upper)
            Pivot (v, x)
            break
    // this is for things like 1+1=2
    foreach (same_value in const_cache.Values)
      if (same_value is [_]) {}
      else
        mutable prev = null
        foreach (v in same_value)
          when (v.term != null && v.term.Active)
            if (prev == null)
              prev = v
            else
              core.TermPool.Merge (prev.term, v.term, Proof.True ())
    // Canonical form (poly, offset) -> first (var, proof) seen with it; a
    // second variable with the same form is merged with the first.
    def ht = Hashtable ()
    def add (canon, (var, proof))
      log (LIN, $"GenerateEq: $var = $canon")
      def (poly, off) = canon
      when (var.term != null && var.term.Active)
        mutable skip = false
        when (poly is [])
          mutable cached
          when (const_cache.TryGetValue (off, out cached))
            mutable other = null
            foreach (v in cached)
              when (v.term != null && v.term.Active)
                when (v : object == var)
                  // already merged by the constant loop above
                  skip = true
                  break
                other = v
            when (!skip && other != null)
              core.TermPool.Merge (other.term, var.term, proof)
              skip = true
        when (!skip)
          mutable tmp
          if (ht.TryGetValue (canon, out tmp))
            def (var', proof') = tmp
            log (LIN, $"GenerateEq, got hit: $(var')")
            core.TermPool.Merge (var.term, var'.term, Proof.Rule2 ("linear-canon", proof, proof'))
          else
            ht [canon] = (var, proof)
    foreach (v in vars)
      if (v.lower == v.upper)
        add (([], v.lower), (v, v.ConstProof))
      else if (v.IsBasic)
        mutable off = real.Zero
        mutable poly = []
        mutable proof = Proof.True ()
        def row = tableau [v.ridx]
        foreach ((x, r) in row)
          when (x.idx != v.idx)
            if (x.lower == x.upper)
              off += r * x.lower
              proof = Proof.Rule2 ("linear-const-trans", proof, x.ConstProof)
            else
              poly ::= (r, x)
        add ((poly, off), (v, proof))
      else
        add (([(real.One, v)], real.Zero), (v, Proof.True ()))

  // try to make the model more integer than it is
  // NOTE(review): in this file Integerize is only referenced from the
  // disabled IntBranch.  found_it keeps a model when it has *fewer or equal*
  // integer values (ints <= best_ints), which looks inverted relative to the
  // comment above - confirm before enabling.
  #pragma warning disable 10003
  Integerize () : void
    assert (!is_model_shaking)
    mutable new_model
    mutable best_ints = 0
    def best = CloneModel ()
    foreach (b in best)
      when (b.IsInt)
        best_ints++
    foreach (v in vars)
      def val = model [v.idx]
      def assert_eq (n)
        () =>
          v.AssertLower (n, Proof.True ())
          v.AssertUpper (n, Proof.True ())
      when (! val.IsInt)
        def found_it ()
          mutable ints = 0
          foreach (m in model)
            when (m.IsInt)
              ints++
          when (ints <= best_ints)
            best_ints = ints
            foreach (i in [0 .. best.Length - 1])
              best [i] = model [i]
        def p1 = TryAssert (best, assert_eq (val.Ceil ()), out new_model)
        if (p1 == null)
          found_it ()
        else
          def p2 = TryAssert (best, assert_eq (val.Floor ()), out new_model)
          if (p2 == null)
            found_it ()
          else
            Refute (Proof.Rule2 ("linear-non-int", p1, p2))

  static max_range : real = real (5, 1)

  // Tries to shrink the bounds of each variable in `vs` towards the range of
  // values seen in satisfying models (seen_low .. seen_high), by halving the
  // gap to the current bound and keeping bounds that can be refuted.
  Narrow (vs : list [Var]) : void
    def seen_low = CloneModel ()
    def seen_high = CloneModel ()
    def backup = CloneModel ()
    // Runs fn under TryAssert; on SAT widens seen_low/seen_high.
    def try_assert (fn)
      mutable new_model
      def proof = TryAssert (backup, () => fn (Proof.True ()), out new_model)
      when (new_model)
        foreach (v in vs)
          when (model [v.idx] > seen_high [v.idx])
            seen_high [v.idx] = model [v.idx]
          when (model [v.idx] < seen_low [v.idx])
            seen_low [v.idx] = model [v.idx]
      proof
    // Tries to refute the opposite of a new bound; asserts the bound if so.
    def try_new (v, bound, low)
      def proof =
        if (low) try_assert (v.AssertUpper (bound, _))
        else try_assert (v.AssertLower (bound, _))
      if (proof == null)
        false
      else
        if (low)
          v.AssertLower (bound, proof)
        else
          v.AssertUpper (bound, proof)
        true
    foreach (v in vs)
      log (TEMP, $"narrow: $(v.Desc)")
      when (v.lower == v.upper)
        continue
      when (seen_high [v.idx] - seen_low [v.idx] > max_range)
        continue
      when (!try_new (v, seen_high [v.idx] - max_range, true))
        continue
      when (!try_new (v, seen_low [v.idx] + max_range, false))
        continue
      mutable fix = false
      while (!fix)
        log (TEMP, $" narrow: $(v.lower) : $(seen_low [v.idx]) : $(model [v.idx]) : $(seen_high [v.idx]) : $(v.upper)")
        fix = true
        def diff = v.upper - seen_high [v.idx]
        when (diff >= real.One)
          // it will either move v.upper or update seen_high[idx],
          // so the diff in next iteration will be at least two times smaller
          _ = try_new (v, seen_high [v.idx] + diff / real.Two, false)
          fix = false
        def diff = seen_low [v.idx] - v.lower
        when (diff >= real.One)
          _ = try_new (v, seen_low [v.idx] - diff / real.Two, true)
          fix = false
        // not really needed, as this won't ever compute range bigger than
        // 2*max_range
        when (seen_high [v.idx] - seen_low [v.idx] > max_range)
          fix = true
      log (TEMP, $" ->narrow: $(v.lower) : $(seen_low [v.idx]) : $(model [v.idx]) : $(seen_high [v.idx]) : $(v.upper)")
  #pragma warning restore 10003

  // Runs DoCheck.  On SAT returns null and keeps the new model.  On UNSAT
  // restores the model from `backup` and returns (and clears) the conflict
  // proof recorded in model_shaking_proof; callers run this with
  // is_model_shaking set so Refute records rather than refutes.
  CallCheck (backup : array [real]) : Proof
    if (DoCheck ())
      assert (model_shaking_proof == null)
      null
    else
      foreach (i in [0 .. backup.Length - 1])
        model [i] = backup [i]
      def p = model_shaking_proof
      model_shaking_proof = null
      assert (p != null)
      p

  // Runs `fn` (which asserts bounds) in a temporary core state with shaking
  // mode on.  Returns the conflict proof, or null if the result is SAT, in
  // which case new_model is set to true.  Always pops the state and leaves
  // shaking mode.
  TryAssert (backup : array [real], fn : void -> void, new_model : out bool) : Proof
    new_model = false
    try
      is_model_shaking = true
      core.PushState ()
      fn ()
      if (model_shaking_proof != null)
        def p = model_shaking_proof
        model_shaking_proof = null
        p
      else
        def p = CallCheck (backup)
        when (p == null)
          new_model = true
        p
    finally
      is_model_shaking = false
      core.PopState ()

  [Profile] \
  public override FinalCheckImpl (full : bool) : void
    Check ()
    when (!core.Refuted)
      GenerateEqualities (full)

  // Canonical linear term -> slack variable defined by that term.
  row_cache : Hashtable [LinearTerm, Var] = Hashtable ()

  // Asserts `c1 head c2` (negated when `neg`) as a bound on a slack variable
  // for the canonical form of c1 - c2.  Strict comparisons use off = 1, i.e.
  // c1 < c2 is asserted as c1 - c2 + 1 <= 0 (presumably relying on integer
  // terms).  A slack row is reused when the term or its negation is cached.
  // NOTE(review): a negated "==" falls through to `_ => head` and is asserted
  // as an equality; presumably never called that way ("==" is internal).
  [Profile] \
  public override AssertPredicate (neg : bool, head : string, children : list [Term], proof : Proof) : void
    assert (current_level != 0)
    def head = match (head)
      | _ when !neg => head
      | ">=" => "<"
      | "<=" => ">"
      | ">" => "<="
      | "<" => ">="
      | _ => head
    log (LIN, $ "$(children.Head) $head $(children.Tail.Head)")
    match ((head, children))
      // == is used only internally
      | ("==", [c1, c2]) with off = real.Zero \
      | ("<=", [c1, c2]) with off = real.Zero \
      | (">=", [c2, c1]) with off = real.Zero \
      | ("<", [c1, c2]) with off = real.One \
      | (">", [c2, c1]) with off = real.One =>
        def tmp = core.TermPool.Get ("-", [c1, c2])
        def (lterm, add) = Canonize (tmp)
        // c1 - c2 = lterm + add
        def add = add + off
        if (lterm is [])
          // it was a constant term
          if (head == "==")
            if (add.IsZero) {}
            else core.Refute (proof)
          else
            if (add.IsPositive) core.Refute (proof)
            else {}
        else
          def neg_lterm = $[(-c, v) | (c, v) in lterm]
          mutable res
          def negated =
            if (row_cache.TryGetValue (lterm, out res))
              false
            else if (row_cache.TryGetValue (neg_lterm, out res))
              true
            else
              res = Var (this, null)
              row_cache [lterm] = res
              AddRow (res, lterm)
              log (LIN, $"addrow/pred[$head]: $res = $lterm")
              false
          // res = lterm  => lterm <= -add  is  res <= -add
          // res = -lterm => lterm <= -add  is  res >= add
          if (negated)
            res.AssertLower (add, proof)
            when (head == "==")
              res.AssertUpper (add, proof)
          else
            res.AssertUpper (-add, proof)
            when (head == "==")
              res.AssertLower (-add, proof)
      | (name, args) =>
        throw FatalError ($ "invalid linear predicate: $name $args")

  // Equalities from the core are asserted as the internal "==" predicate.
  public override TellEquality (u1 : Term, u2 : Term, proof : Proof) : void
    AssertPredicate (false, "==", [u1, u2], proof)