using Nemerle.Collections
using Nemerle.Logging
using Nemerle.Profiling
using Nemerle.Imperative

set namespace Fx7

// Decision procedure for UTVPI constraints: inequalities a*x + b*y <= d with
// a, b in {-1, 0, 1} and integer d, plus equalities u1 = sign*u2 + offset.
//
// Equalities are kept in a union-find-like form: every Var records
// `this = offset + sign*other`, and a Var is a representative ("base") when
// `other == this && sign == 1`. Inequalities and single-variable limits are
// only attached to base variables (AddIneq normalizes before storing).
// Equalities implied by inequalities are queued in `equalities` and drained at
// the end of AddIneq.
[CreateMemento] \
public class UtvpiTheory : PlainVarTheory
  #region nested classes
  // Pending equality x = b*y + d, derived from inequalities; it is asserted
  // later through AssertEq (..., use_ineq = false).
  [Record] \
  internal class Equality
    // x = by + d
    public x : Var
    public b : int
    public y : Var
    public d : int
    public proof : Proof

  // Two-variable inequality a*x + b*y <= d. DoAddIneq registers every stored
  // instance in both x.ineqs and y.ineqs.
  [Record] \
  internal class Inequality
    // ax + by <= d
    public a : int
    public x : Var
    public b : int
    public y : Var
    public d : int
    public proof : Proof

    // Detaches this inequality from the ineqs lists of both of its variables.
    public Suicide () : void
      x.WillWrite ()
      y.WillWrite ()
      x.ineqs = x.ineqs.Filter (_ != this : object)
      y.ineqs = y.ineqs.Filter (_ != this : object)

    // Given one endpoint `o`, returns
    // (coefficient of o, coefficient of the other variable, the other variable).
    public Rest (o : Var) : int * int * Var
      if (o : object == this.x)
        (a, b, y)
      else if (o : object == this.y)
        (b, a, x)
      else
        assert (false)

    // True when this inequality has coefficient a on x and b on y, in either
    // order of the two variables. The bound d is not compared.
    public Is (a : int, x : Var, b : int, y : Var) : bool
      (a == this.a && b == this.b && x : object == this.x && y : object == this.y) || (a == this.b && b == this.a && x : object == this.y && y : object == this.x)

    public override ToString () : string
      $ "$a*$x + $b*$y <= $d"

  // A theory variable attached to a Term.
  [Rollbackable] \
  public class Var : BaseVar
    pool : UtvpiTheory
    [Copy] internal mutable ineqs : list [Inequality] = []
    // Single-variable bounds: Limit [1] = L encodes x <= L and
    // Limit [-1] = L encodes -x <= L. A bound exists iff its proof is non-null.
    [Copy] mutable neg_limit : int
    [Copy] mutable pos_limit : int
    [Copy] mutable neg_limit_proof : Proof
    [Copy] mutable pos_limit_proof : Proof

    // True when this variable carries any limit or two-variable inequality.
    public HasIneq : bool
      get
        neg_limit_proof != null || pos_limit_proof != null || !(ineqs is [])

    // a must be 1 or -1.
    public HasLimit (a : int) : bool
      if (a == -1)
        neg_limit_proof != null
      else if (a == 1)
        pos_limit_proof != null
      else
        assert (false)

    public Limit [a : int] : int
      get
        if (a == -1)
          neg_limit
        else if (a == 1)
          pos_limit
        else
          assert (false)
      set
        WillWrite ()
        if (a == -1)
          neg_limit = value
        else if (a == 1)
          pos_limit = value
        else
          assert (false)

    public LimitProof [a : int] : Proof
      get
        def res =
          if (a == -1)
            neg_limit_proof
          else if (a == 1)
            pos_limit_proof
          else
            assert (false)
        assert (res != null)
        res
      set
        WillWrite ()
        if (a == -1)
          neg_limit_proof = value
        else if (a == 1)
          pos_limit_proof = value
        else
          assert (false)

    // True when this variable has been merged into the constant zero,
    // i.e. its value is just `offset`.
    public IsConst : bool
      get
        other : object == pool.zero

    // this = offset + sign*other
    [Copy] public mutable offset : int
    [Copy] public mutable sign : int
    [Copy] public mutable other : Var
    // { u | u.other = this }
    [Copy] public mutable users : list [Var]
    [Copy] public mutable proof : Proof

    // Creates a fresh base variable (this = 1*this + 0) for term t.
    internal this (p : UtvpiTheory, t : Term)
      base (t)
      assert (t != null)
      pool = p
      // log (CNT, $"create for id=$(t.Id) $t")
      offset = 0
      sign = 1
      other = this
      proof = Proof.True ()
      users = [this]

    public override ToString () : string
      def id = if (term == null) 0 else term.Id
      if (IsBase)
        $ "u$id:$term"
      else
        $ "u$id[$sign*$(other)+$offset]:$term"

    // True for a representative of its equivalence class.
    public IsBase : bool
      get { other : object == this && sign == 1 }
  #endregion

  // The variable for the constant term "0"; it is never killed.
  zero : Var
  // Equalities derived from inequalities, drained at the end of AddIneq.
  equalities : Queue [Equality] = Queue ()
  my_symbols : Hashtable [string, bool] =
    def ht = Hashtable ()
    ["+", "-", "0", "1", "-1"].Iter (ht.Add (_, true))
    ht
  internal mutable current_level : int = 0
  [Copy] mutable rollback_queue : list [Var] = []

  #region Rollback handling
  public override PushState () : void
    current_level++
    SaveMemento ()
    assert (equalities.IsEmpty)
    rollback_queue = []

  public override PopState () : void
    current_level--
    foreach (r in rollback_queue)
      r.Rollback ()
    equalities.Clear ()
    RestoreMemento ()

  QueueRollback (t : Var) : void
    rollback_queue ::= t
  #endregion

  internal this (id : int, c : Core)
    base (c, id)
    zero = Var (this, core.TermPool.Get ("0", []))

  // True when `s` is an optional leading '-' followed by one or more
  // characters accepted by char.IsDigit. A lone "-" is not a number
  // (previously it was accepted, and int.Parse ("-") would then throw).
  IsNumber (s : string) : bool
    if (s == null || s == "" || s == "-")
      false
    else
      def start = if (s [0] == '-') 1 else 0
      foreach (i in [start .. s.Length - 1])
        when (!char.IsDigit (s [i]))
          return (false)
      true

  public override IsMyFunction (head : string) : bool
    head != "" && (my_symbols.Contains (head) || ((head [0] == '-' || char.IsDigit (head [0])) && IsNumber (head)))

  public override IsMyPredicate (head : string) : bool
    head == "<" || head == ">" || head == "<=" || head == ">="

  public override MakeAlien (t : Term, _outer : bool) : void
    _ = GetVar (t)

  vars : Hashtable [Term, Var] = Hashtable ()

  // Returns the Var for term t, creating it on first use. When the term's head
  // is a theory function, its structure is asserted immediately.
  GetVar (t : Term) : Var
    mutable v
    if (vars.TryGetValue (t, out v))
      v
    else
      v = if (t.Name == "0") zero else Var (this, t)
      vars [t] = v
      when (IsMyFunction (t.Name))
        TellStructure (t, t.Name, t.Children)
      v

  #region EQ handling
  // u1 = sign*u2 + offset
  // u1 is killed: every user of u1 is re-pointed at u2, and terms that end up
  // with the same (sign, offset) over u2 are merged in the term pool.
  // NOTE(review): u1's inequalities are dropped (Suicide) and its limits are
  // not transferred to u2; when reached with use_ineq = false this may lose
  // information - confirm this is intended.
  DoAssertEq (u1 : Var, sign : int, u2 : Var, offset : int, proof : Proof) : void
    log (CNT, $ "DoAssertEq: $u1 = $sign*$u2 + $offset proof=$proof")
    assert (u1.IsBase && u2.IsBase && u1 : object != u2)
    assert (u1 : object != zero)
    def sign = if (u2 : object == zero) 1 else sign

    foreach (ineq in u1.ineqs)
      ineq.Suicide ()

    def sig (u : Var)
      (u.sign, u.offset)

    u2.WillWrite ()
    def by_offset = Hashtable ()
    foreach (u in u2.users)
      log (CNT, $ "DoAssertEq: by_off[$(sig(u))] = $u")
      by_offset [sig (u)] = u

    foreach (u in u1.users)
      log (CNT, $ "DoAssertEq: handle: $u sign=$(u.sign) offset=$(u.offset)")
      assert (u.other : object == u1, $"u.oth=$(u.other) u1=$(u1) u2=$(u2)")
      assert (u : object != u2)
      u.WillWrite ()
      u2.users ::= u
      // u = u.sign * u1 + u.offset
      // u1 = sign * u2 + offset
      // u = u.sign * (sign * u2 + offset) + u.offset
      // u = u.sign * sign * u2 + u.sign * offset + u.offset
      u.other = u2
      u.offset += u.sign * offset
      u.sign *= sign
      u.proof = Proof.Rule2 ("count-trans1", u.proof, proof)
      when (u2 : object == zero)
        u.sign = 1
      mutable other = null
      if (by_offset.TryGetValue (sig (u), out other))
        log (CNT, $ "DoAssertEq: prop merge $other = $u")
        core.TermPool.Merge (other.term, u.term, Proof.Rule2 ("count-trans2", other.proof, u.proof))
        // Proof.Rule2 ($"count-trans2->$(other.term)=$(u.term)<-", other.proof, u.proof))
      else
        by_offset [sig (u)] = u

  // u1 = sign*u2 + offset
  // Normalizes both sides to their base variables. Then it handles the
  // degenerate same-variable case, picks which side to kill, and either
  // records the equality as two inequalities (when use_ineq and u1 carries
  // inequalities) or merges directly via DoAssertEq.
  AssertEq (u1 : Var, sign : int, u2 : Var, offset : int, proof : Proof, use_ineq = true) : void
    log (CNT, $ "AssertEq: $u1 = $sign*$u2 + $offset proof=$proof")
    assert (u1 != null && u2 != null)
    assert (sign == 1 || sign == -1)
    if (!u1.IsBase || !u2.IsBase)
      log (CNT, $ "AssertEq: need norm")
      AssertEq (u1.other, sign * u1.sign * u2.sign, u2.other, u1.sign * (sign * u2.offset + offset - u1.offset), Proof.Rule3 ("count-norm", u1.proof, u2.proof, proof))
    else if (u1 : object == u2)
      log (CNT, $ "AssertEq: same var")
      if (sign == 1)
        when (offset != 0)
          core.Refute (proof)
      else if (u1 : object == zero)
        when (offset != 0)
          core.Refute (proof)
      // u = -u + offset  <=>  2u = offset, unsatisfiable over integers when offset is odd
      else if (offset % 2 != 0)
        core.Refute (proof)
      else
        AssertEq (u1, 1, zero, offset / 2, proof)
    else if (u1 : object == zero || (u2 : object != zero && u1.HasIneq && !u2.HasIneq))
      // we cannot kill zero,
      // and we don't like killing variables involved in inequalities, as it is more expensive
      log (CNT, "need swap")
      AssertEq (u2, sign, u1, -sign * offset, proof)
    else
      // if u1 is involved in some inequalities, we cannot just kill it
      if (use_ineq && u1.HasIneq)
        log (CNT, "AssertEq: have ineq")
        AddIneq ( 1, u1, -sign, u2, offset, proof)
        AddIneq (-1, u1, sign, u2, -offset, proof)
      else
        // otherwise we can
        DoAssertEq (u1, sign, u2, offset, proof)
  #endregion

  // Asserts (possibly negated) <, >, <=, >= between two terms, rewriting each
  // as c1 - c2 <= off (off = -1 for the strict forms).
  public override AssertPredicate (neg : bool, head : string, children : list [Term], proof : Proof) : void
    def head =
      match (head)
        | _ when !neg => head
        | ">=" => "<"
        | "<=" => ">"
        | ">" => "<="
        | "<" => ">="
        | _ => head
    match ((head, children))
      | ("<=", [c1, c2]) with off = 0 \
      | (">=", [c2, c1]) with off = 0 \
      | ("<", [c1, c2]) with off = -1 \
      | (">", [c2, c1]) with off = -1 =>
        def c1 = GetVar (c1)
        def c2 = GetVar (c2)
        log (CNT, $ "AssertPredicate: $neg $head $children by proof=$proof")
        AddIneq (1, c1, -1, c2, off, proof)
      | (name, args) =>
        throw FatalError ($ "invalid UTVPI predicate: $name $args")

  // Asserts the defining equation of a theory-function term u = head(children).
  // Binary + and - are accepted only when one argument is already a constant.
  TellStructure (u : Term, head : string, children : list [Term]) : void
    def u = GetVar (u)
    match ((head, children))
      // u = c1 +- c2
      | ("-", [c1, c2]) with sign = -1 \
      | ("+", [c1, c2]) with sign = 1 =>
        def c1 = GetVar (c1)
        def c2 = GetVar (c2)
        if (c1.IsConst)
          AssertEq (u, sign, c2, c1.offset, c1.proof)
        else if (c2.IsConst)
          AssertEq (u, 1, c1, sign * c2.offset, c2.proof)
        else
          throw FatalError ($ "not UTVPI: $u = ($head $c1 $c2)")
      | ("-", [c1]) \
      | ("~", [c1]) =>
        def c1 = GetVar (c1)
        AssertEq (u, -1, c1, 0, Proof.True ())
      | (k, []) when IsNumber (k) =>
        AssertEq (u, 1, zero, int.Parse (k), Proof.True ())
      | (name, args) =>
        throw FatalError ($ "invalid UTVPI function: $name $args")

  public override TellEquality (u1 : Term, u2 : Term, proof : Proof) : void
    AssertEq (GetVar (u1), 1, GetVar (u2), 0, proof)

  #region UTVPI
  // Integer division rounding toward negative infinity. .NET `/` truncates
  // toward zero, which rounds negative quotients up and so yields a weaker
  // bound than the integers allow. Divides bound sums of integer terms.
  static FloorDiv (n : int, k : int) : int
    def q = n / k
    if (n % k != 0 && ((n < 0) != (k < 0))) q - 1 else q

  // Stores a*x + b*y <= d unless an inequality over the same (a, x, b, y)
  // with a bound <= d already exists. A strictly weaker existing one is
  // replaced. Queues equalities implied by opposite inequalities or by the
  // limits of x and y. Returns true when the new inequality was stored.
  DoAddIneq (a : int, x : Var, b : int, y : Var, d : int, proof : Proof) : bool
    log (CNT, $"DoAddIneq: $a*$x + $b*$y <= $d proof=$proof")
    when (x.HasLimit (-a) && y.HasLimit (-b) && d == -x.Limit [-a] - y.Limit [-b])
      equalities.Push (Equality (x, -a * b, y, a * d, Proof.Rule3 ("ineq-was3", proof, x.LimitProof [-a], y.LimitProof [-b])))

    mutable the_other = null
    foreach (ineq in x.ineqs)
      def (a', b', y') = ineq.Rest (x)
      log (CNT, $"consider: $ineq $(a') $(b') $(y')=$(y)")
      when (y' : object == y)
        log (CNT, $"y ok!")
        if (a' == a && b' == b)
          log (CNT, $"hit!")
          assert (the_other == null)
          the_other = ineq
        // -ax - by <= -d together with ax + by <= d gives ax + by = d
        else if (a' == -a && b' == -b && d == -ineq.d)
          equalities.Push (Equality (x, -a * b, y, a * d, Proof.Rule2 ("ineq-was", proof, ineq.proof)))
        else {}

    if (the_other == null || d < the_other.d)
      when (the_other != null)
        the_other.Suicide ()
      log (CNT, $"x.ineqs: $(x.ineqs)")
      log (CNT, $"y.ineqs: $(y.ineqs)")
      def ineq = Inequality (a, x, b, y, d, proof)
      x.WillWrite ()
      y.WillWrite ()
      x.ineqs ::= ineq
      y.ineqs ::= ineq
      true
    else
      false

  // Tightens the limit a*x <= d when it improves on the current one (a is
  // 1 or -1). Queues x = a*d when it meets the opposite limit exactly.
  // Returns true when the limit changed.
  DoAddIneq (a : int, x : Var, d : int, proof : Proof) : bool
    log (CNT, $"DoAddIneq: $a*$x <= $d proof=$proof")
    if (!x.HasLimit (a) || x.Limit [a] > d)
      x.Limit [a] = d
      x.LimitProof [a] = proof
      when (x.HasLimit (-a) && x.Limit [-a] == -d)
        equalities.Push (Equality (x, 1, zero, a * d, Proof.Rule2 ("ineq-was2", proof, x.LimitProof [-a])))
      true
    else
      false

  // Adds a*x <= d for a base variable x. Refutes against the opposite limit;
  // otherwise, if the limit tightened, derives b*y <= d + d' from each stored
  // -a*x + b*y <= d'.
  AddIneq1 (a : int, x : Var, d : int, proof : Proof) : void
    assert (a != 0)
    assert (x.IsBase)
    if (x.HasLimit (-a) && x.Limit [-a] + d < 0)
      core.Refute (Proof.Rule2 ("ineq-single", proof, x.LimitProof [-a]))
    else if (DoAddIneq (a, x, d, proof))
      foreach (ineq in x.ineqs)
        def (a', b, y) = ineq.Rest (x)
        log (CNT, $"AddIneq1: consider $ineq, rest($x) = ($(a'), $b, $y)")
        when (-a == a')
          _ = DoAddIneq (b, y, d + ineq.d, Proof.Rule2 ("ineq-single-t", ineq.proof, proof))
    else {}

  // Adds a*x + b*y <= d for distinct base variables x and y. Checks for
  // direct contradictions, stores the inequality, and then derives the
  // consequences of combining it with the limits and inequalities of x and y.
  AddIneq2 (a : int, x : Var, b : int, y : Var, d : int, proof : Proof) : void
    assert (a != 0 && b != 0 && x : object != y)
    assert (x.IsBase && y.IsBase)

    // check for contradiction
    // 1. -ax - by <= d' && d + d' < 0
    foreach (ineq in x.ineqs)
      log (CNT, $"AddIneq2: check $ineq")
      when (ineq.Is (-a, x, -b, y) && ineq.d + d < 0)
        core.Refute (Proof.Rule2 ("ineq-double", proof, ineq.proof))
    // 2. -ax <= d' && -by <= d'' && d + d' + d'' < 0
    when (x.HasLimit (-a) && y.HasLimit (-b) && x.Limit [-a] + y.Limit [-b] + d < 0)
      core.Refute (Proof.Rule3 ("ineq-2", proof, x.LimitProof [-a], y.LimitProof [-b]))

    // when in trouble or ineq not needed say bye
    when (core.Refuted || !DoAddIneq (a, x, b, y, d, proof))
      return

    when (x.HasLimit (-a))
      // { by <= d + d' : -ax <= d' }
      _ = DoAddIneq (b, y, d + x.Limit [-a], Proof.Rule2 ("ineq-4a", proof, x.LimitProof [-a]))
    when (y.HasLimit (-b))
      // { ax <= d + d' : -by <= d' }
      _ = DoAddIneq (a, x, d + y.Limit [-b], Proof.Rule2 ("ineq-4b", proof, y.LimitProof [-b]))

    foreach (ineq in x.ineqs)
      def (a', e, z) = ineq.Rest (x)
      log (CNT, $"AddIneq: for $ineq [rest($x) = ($(a'), $e, $z)]")
      if (a' == -a)
        if (z : object != y)
          // { by+ez <= d+d' : -ax+ez <= d', z!=y }
          _ = DoAddIneq (b, y, e, z, d + ineq.d, Proof.Rule2 ("ineq-3a", proof, ineq.proof))
          when (y.HasLimit (-b))
            // { ez <= d + d' + d'' : -by <= d', -ax + ez <= d'', z!=y }
            _ = DoAddIneq (e, z, d + ineq.d + y.Limit [-b], Proof.Rule3 ("ineq-4d", proof, ineq.proof, y.LimitProof [-b]))
          foreach (ineq' in y.ineqs)
            def (b', f, t) = ineq'.Rest (y)
            when (b' == -b && t : object != x)
              // NOTE(review): when t == z and e == -f the three inequalities sum
              // to 0 <= d + d' + d''; a negative sum is a conflict that is not
              // reported here - confirm whether it is caught elsewhere.
              if (t : object == z)
                when (e == f)
                  // { ez <= (d+d'+d'')/2 : -ax+ez <= d' && -by+ez <= d'' }
                  _ = DoAddIneq (e, z, FloorDiv (d + ineq.d + ineq'.d, 2), Proof.Rule3 ("ineq-4g", proof, ineq.proof, ineq'.proof))
              else
                // { ez + ft <= d + d' + d'' : -ax+ez <= d', z!=y,
                //                             -by+ft <= d'', t!=x, t!= z }
                _ = DoAddIneq (e, z, f, t, d + ineq.d + ineq'.d, Proof.Rule3 ("ineq-3c", proof, ineq.proof, ineq'.proof))
        else if (e == b)
          // { by <= (d+d')/2 : -ax+by <= d' }
          _ = DoAddIneq (b, y, FloorDiv (d + ineq.d, 2), Proof.Rule2 ("ineq-4e", proof, ineq.proof))
        else {}
      else if (a' == a && e == -b && z : object == y)
        // { ax <= (d+d')/2 : ax - by <= d' }
        _ = DoAddIneq (a, x, FloorDiv (d + ineq.d, 2), Proof.Rule2 ("ineq-4f", proof, ineq.proof))
      else {}

    foreach (ineq in y.ineqs)
      def (b', f, t) = ineq.Rest (y)
      when (b' == -b && t : object != x)
        // { ax + ft <= d + d'' : -by+ft <= d'', t!=x }
        _ = DoAddIneq (a, x, f, t, d + ineq.d, Proof.Rule2 ("ineq-3b", proof, ineq.proof))
        when (x.HasLimit (-a))
          // { ft <= d + d' + d'' : -ax <= d', -by + ft <= d'', t!=x }
          _ = DoAddIneq (f, t, d + ineq.d + x.Limit [-a], Proof.Rule3 ("ineq-4c", proof, ineq.proof, x.LimitProof [-a]))

  // Entry point for a*x + b*y <= d (y may be null for a single-variable
  // constraint). Normalizes both variables to their base representatives,
  // collapses degenerate forms, and dispatches to AddIneq1 / AddIneq2. It then
  // drains every equality queued along the way.
  AddIneq (a : int, x : Var, b : int, y : Var, d : int, proof : Proof) : void
    log (CNT, $"AddIneq: $a*$x + $b*$y <= $d")
    if (!x.IsBase)
      AddIneq (a * x.sign, x.other, b, y, d - a * x.offset, Proof.Rule2 ("ineq-norm1", x.proof, proof))
    else if (y != null && !y.IsBase)
      AddIneq (a, x, b * y.sign, y.other, d - b * y.offset, Proof.Rule2 ("ineq-norm2", y.proof, proof))
    else if (x : object == y)
      AddIneq (a + b, x, 0, null, d, proof)
    else if (y == null && (a > 1 || a < -1))
      // a*x <= d  <=>  sign(a)*x <= floor (d / |a|) over the integers.
      // Dividing by `a` itself (as before) flipped the inequality for a < 0.
      def scale = if (a > 0) a else -a
      AddIneq (a / scale, x, 0, null, FloorDiv (d, scale), proof)
    else if (a == 0 && b == 0)
      when (d < 0)
        core.Refute (proof)
    else if (x : object == zero)
      if (y != null)
        AddIneq (b, y, 0, null, d, proof)
      else
        when (d < 0)
          core.Refute (proof)
    else if (y : object == zero)
      AddIneq (a, x, 0, null, d, proof)
    else
      assert (a == 1 || a == -1, $"a=$a")
      assert (b == 0 || b == 1 || b == -1)
      if (y == null)
        AddIneq1 (a, x, d, proof)
      else
        AddIneq2 (a, x, b, y, d, proof)

    while (!equalities.IsEmpty)
      def eq = equalities.Take ()
      AssertEq (eq.x, eq.b, eq.y, eq.d, eq.proof, use_ineq = false)
  #endregion