using System.Text
using Nemerle.Collections
using Nemerle.Utility
using Nemerle.Logging
using Nemerle.Imperative

set namespace Fx7

// Proof-producing procedure for UTVPI constraints: integer inequalities of
// the form  [-]a + [-]b <= k  (see Constraint).  Arithmetic literals are
// normalised into such constraints (MakeConstraints / DoMakeConstraints /
// Simplify), combined pairwise until saturation (Infer / TryCross), and either
// a contradiction proof is produced (TryFindContradiction) or implied
// equalities are collected for later term rewriting (ScanForEqualities).
public class ArithProofs
  type E = ProofGen.E
  type Z = int

  // Stored constraints, indexed by every (non-null) term they mention.
  constraints_by_term : Hashtable [Term, list [Constraint]] = Hashtable ()
  // Pairs already combined by TryCross (both orders are recorded).
  crossed_constraints : Hashtable [Constraint * Constraint, bool] = Hashtable ()
  // Tightest bound stored so far for each orientation-normalised left-hand
  // side (see SignatureOf); used to discard subsumed constraints.
  best_bounds : Hashtable [bool * Term * bool * Term, Z] = Hashtable ()
  // All stored constraints, newest first; Infer compares list identity to
  // detect whether a round added anything.
  mutable all_constraints : list [Constraint] = []

  // A single UTVPI constraint  [neg1] term1 + [neg2] term2 <= bound  together
  // with its proof.  Either term may be null, meaning that summand is absent.
  [Record] \
  class Constraint
    public neg1 : bool
    public term1 : Term
    public neg2 : bool
    public term2 : Term
    public bound : Z
    public proof : E

    public override ToString () : string
      def neg (n)
        if (n) "-" else ""
      $ "$(neg(neg1))$term1 + $(neg(neg2))$term2 <= $bound"

    public override GetHashCode () : int
      unchecked
        (if (term1 != null) term1.Id else 0) + (if (term2 != null) term2.Id * 13 else 0) + bound

    // Structural equality: same terms (by reference), signs and bound; the
    // proof is ignored.
    [Nemerle.OverrideObjectEquals] \
    public Equals (other : Constraint) : bool
      term1 === other.term1 && term2 === other.term2 && bound == other.bound && neg1 == other.neg1 && neg2 == other.neg2

    // Builds  t1 + t2 <= bound, turning a top-level "~" on either term into
    // the corresponding neg flag.
    public this (t1 : Term, t2 : Term, bound : Z, proof : E)
      this.bound = bound
      this.proof = proof
      when (t1 != null)
        if (t1.Name == "~")
          term1 = t1.OnlyChild
          neg1 = true
        else
          term1 = t1
      when (t2 != null)
        if (t2.Name == "~")
          term2 = t2.OnlyChild
          neg2 = true
        else
          term2 = t2

  get_lit : Literal -> E
  pool : Term.Pool
  zero : Term
  one : Term
  // Set by Simplify when a term has more than two variables.
  mutable broken : bool
  // Set once a contradiction has been found; stops all further saving.
  mutable stop : bool
  // Proof of the contradiction, or null if none was found.
  mutable contradiction : E

  GetConstant (z : Z) : Term
    pool.Get (z.ToString (), [])

  // floor (n / 2): the arithmetic right shift on a signed int rounds towards
  // minus infinity, which is the correct tightening for integer bounds.
  Div2 (n : Z) : Z
    n >> 1

  // Orientation-independent key for the left-hand side of a constraint:
  // "a + b" and "b + a" (and "t + null" / "null + t") map to the same tuple.
  // NOTE(review): assumes Term.Id distinguishes pooled terms; if two distinct
  // terms share an Id the only effect is two keys for one left-hand side.
  static SignatureOf (c : Constraint) : bool * Term * bool * Term
    def swap =
      if (c.term1 == null) c.term2 != null
      else c.term2 != null && c.term2.Id < c.term1.Id
    if (swap) (c.neg2, c.term2, c.neg1, c.term1)
    else (c.neg1, c.term1, c.neg2, c.term2)

  // Combines c1 and c2 (at most once per unordered pair) if they mention the
  // same term with opposite signs, eliminating that term and saving the sum.
  // trans reorders its arguments (swapping the two constraints, or the two
  // summands of the second one via utvpi_swap) until t1 and t3 are the same
  // term with n1 = false and n3 = true, then adds the two inequalities.
  TryCross (c1 : Constraint, c2 : Constraint) : void
    when (!crossed_constraints.Contains ((c1, c2)))
      crossed_constraints [(c1, c2)] = true
      crossed_constraints [(c2, c1)] = true
      // Inside trans, c1/c2 are the integer bounds and p1/p2 the proofs.
      def trans (n1, t1, n2, t2, c1, p1, n3, t3, n4, t4, c2, p2)
        if (t1 === t3 && n1 != n3)
          if (n3)
            if (t2 != null && t2 === t4)
              if (n2 == n4)
                // 2 * [n2]t2 <= c1 + c2, so tighten to [n2]t2 <= floor ((c1 + c2) / 2)
                SaveConstraint (Constraint (n2, t2, false, null, Div2 (c1 + c2), E.App ("utvpi_tight", [p1, p2])))
              else
                // the remaining terms cancel too: 0 <= c1 + c2
                def kind = if (n2) 2 else 1
                SaveConstraint (Constraint (false, null, false, null, (c1 + c2), E.App ($"utvpi_contr$kind", [p1, p2])))
            else
              SaveConstraint (Constraint (n2, t2, n4, t4, c1 + c2, E.App ("utvpi_trans", [p1, p2])))
          else
            trans (n3, t3, n4, t4, c2, p2, n1, t1, n2, t2, c1, p1)
        else if ((t1 === t4 && n1 != n4) || (t2 === t4 && n2 != n4))
          trans (n1, t1, n2, t2, c1, p1, n4, t4, n3, t3, c2, E.App ("utvpi_swap", [p2]))
        else if (t2 === t3 && n2 != n3)
          trans (n3, t3, n4, t4, c2, p2, n1, t1, n2, t2, c1, p1)
        else {}
      trans (c1.neg1, c1.term1, c1.neg2, c1.term2, c1.bound, c1.proof, c2.neg1, c2.term1, c2.neg2, c2.term2, c2.bound, c2.proof)

  // Saturation: crosses every stored constraint with every constraint sharing
  // a term, and repeats while a round stores something new and no
  // contradiction has been found.  Termination relies on SaveConstraint
  // discarding constraints that are subsumed by an already-stored one.
  Infer () : void
    def prev = all_constraints
    foreach (c in all_constraints)
      when (!stop && c.term1 != null)
        foreach (c' in constraints_by_term [c.term1])
          TryCross (c, c')
      when (!stop && c.term2 != null)
        foreach (c' in constraints_by_term [c.term2])
          TryCross (c, c')
    if (!stop && prev !== all_constraints)
      Infer ()
    else {}

  // term -> (equal term, proof of equality); consulted first by Simplify.
  var_equalities : Hashtable [Term, Term * E] = Hashtable ()

  // Rebuilds var_equalities from the stored constraints:
  //  * t = k when the tightest upper and lower bounds on t coincide;
  //  * t1 = [-]t2 + k when two opposite two-term constraints have bounds
  //    that cancel (c1.bound == -c2.bound).
  // When several equalities are available for the same term, add_equality
  // keeps the right-hand side that SimplerThan prefers.
  public ScanForEqualities () : void
    def a_minus_b = Hashtable ()
    def bounds = Hashtable ()
    var_equalities.Clear ()

    def add_equality (t1, t2, proof)
      log (APROOF, $"consider add eq: $t1 = $t2 (by $proof)")
      def replace =
        if (var_equalities.Contains (t1))
          def (t2', _) = var_equalities [t1]
          if (t2' !== t2)
            if (t2'.Name == "+" && t2.Name == "+")
              t2.Child_1of2.SimplerThan (t2'.Child_1of2)
            else
              t2.SimplerThan (t2')
          else false
        else true
      when (replace)
        log (APROOF, $"adding eq: $t1 = $t2 (by $proof)")
        var_equalities [t1] = (t2, proof)

    def maybe_equality (c1, c2)
      log (APROOF, $ "me: $c1 $c2")
      assert (c1.term1 !== c1.term2, $ "c1=$c1 c2=$c2")
      when (c1.bound == -c2.bound)
        if (c1.term1 === c2.term2 && c1.term2 === c2.term1)
          // bring c1 into the same term order as c2 and retry
          def c1 = Constraint (c1.neg2, c1.term2, c1.neg1, c1.term1, c1.bound, E.App ("utvpi_swap", [c1.proof]))
          maybe_equality (c1, c2)
        else if (c1.term1 === c2.term1 && c1.term2 === c2.term2 && c1.neg1 != c2.neg1 && c1.neg2 != c2.neg2)
          if (c1.neg1)
            maybe_equality (c2, c1)
          else
            log (APROOF, $"building from $c1 $c2 $(c1.term1.SimplerThan (c1.term2))")
            def (suff, neg) = if (!c1.neg2) ("p", t => pool.Get ("~", [t])) else ("", t => t)
            def eq = E.App ("leq_antysymm", [E.App ("utvpi_rev2" + suff, [c2.proof]), c1.proof])
            if (c1.term1.SimplerThan (c1.term2))
              // NOTE(review): in the "p" case (c1 is t1 + t2 <= k) the implied
              // equality is t2 = ~t1 + k, i.e. uses c1.bound, whereas c2.bound
              // (= -k) is used here; the non-"p" case (t1 - t2 = k, so
              // t2 = t1 + c2.bound) is consistent.  Verify against the
              // definition of minus_eq_def2p.
              add_equality (c1.term2, pool.Get ("+", [neg (c1.term1), GetConstant (c2.bound)]), E.App ("minus_eq_def2" + suff, [eq]))
            else
              add_equality (c1.term1, pool.Get ("+", [neg (c1.term2), GetConstant (c1.bound)]), E.App ("minus_eq_def1" + suff, [eq]))
        else {}

    // bounds [t] = (tightest upper bound, tightest lower bound), each paired
    // with its proof.  A non-negated constraint  t <= b  gives an upper bound
    // of b; a negated one  -t <= b  gives a lower bound of -b.
    def add_bound (neg, t, b, c)
      def c = if (c.term1 == null) E.App ("utvpi_swap", [c.proof]) else c.proof
      unless (bounds.Contains (t))
        bounds [t] = (None (), None ())
      if (neg)
        // keep the largest (strongest) lower bound
        match (bounds [t])
          | (b1, Some ((b2, _))) when -b > b2 \
          | (b1, None) =>
            bounds [t] = (b1, Some ((-b, c)))
          | _ => {}
      else
        // keep the smallest (strongest) upper bound
        match (bounds [t])
          | (Some ((b1, _)), b2) when b < b1 \
          | (None, b2) =>
            bounds [t] = (Some ((b, c)), b2)
          | _ => {}
      match (bounds [t])
        | (Some ((b1, c1)), Some ((b2, c2))) when b1 == b2 =>
          add_equality (t, GetConstant (b1), E.App ("eq_symm", [E.App ("leq_antysymm", [E.App ("utvpi_drop_zero", [c1]), E.App ("utvpi_rev", [c2])])]))
        | _ => {}

    // Two-term constraints are indexed under both term orders so that an
    // opposite constraint can be found regardless of its orientation.
    foreach (lst in constraints_by_term.Values)
      foreach (c in lst)
        log (APROOF, $"consider: $c")
        if (c.term1 == null)
          assert (c.term2 != null)
          add_bound (c.neg2, c.term2, c.bound, c)
        else if (c.term2 == null)
          assert (c.term1 != null)
          add_bound (c.neg1, c.term1, c.bound, c)
        else
          def sig = (c.term1, c.term2)
          if (a_minus_b.Contains (sig))
            foreach (c' in a_minus_b [sig])
              maybe_equality (c, c')
            a_minus_b [sig] ::= c
          else
            a_minus_b [sig] = [c]
          def sig' = (c.term2, c.term1)
          if (a_minus_b.Contains (sig'))
            a_minus_b [sig'] ::= c
          else
            a_minus_b [sig'] = [c]
        log (APROOF, $"done: $c")

  // Returns t rewritten as (A + B) + C (missing variables replaced by zero)
  // together with the normalisation proof produced by Simplify.
  // NOTE(review): when Simplify cannot normalise t (more than two variables)
  // the returned term/proof are not meaningful; the failure is only logged.
  public SimplifyTerm (t : Term) : Term * E
    broken = false
    def (a, b, c, proof) = Simplify (t)
    when (broken)
      log (WARN, $"failed to normalize: $t")
      // do not let this failure leak into a later DoMakeConstraints
      broken = false
    def a = if (a == null) zero else a
    def b = if (b == null) zero else b
    (pool.Get ("+", [pool.Get ("+", [a, b]), GetConstant (c)]), proof)

  // Try to convert a term to the form A + B + C, where A and B are (possibly
  // negated, possibly null) UTVPI variables and C is an integer constant.
  // Returns (A, B, C, proof of the rewriting).  Known equalities from
  // var_equalities are applied first.  Multiplication is only understood
  // when one factor is the constant 2; other products are treated as opaque
  // variables.  Sets `broken` (returning nulls) when the sum has more than
  // two variables.
  Simplify (t : Term) : Term * Term * Z * E
    def neg (t)
      if (t == null) null
      else if (t.Name == "~") t.OnlyChild
      else pool.Get ("~", [t])
    match (t.Name)
      | _ when var_equalities.Contains (t) =>
        def (t', e1) = var_equalities [t]
        def (a, b, c, e2) = Simplify (t')
        (a, b, c, E.App ("eq_trans", [e1, e2]))
      | "*" =>
        // rewrite 2 * x (or x * 2) as x + x
        def res (t, name)
          def e1 = E.App (name, [E.Term (t, prop = false)])
          def t' = pool.Get ("+", [t, t])
          def (a, b, c, e2) = Simplify (t')
          (a, b, c, E.App ("eq_trans", [e1, e2]))
        if (t.Child_1of2 === GetConstant (2))
          res (t.Child_2of2, "arith_mul2l")
        else if (t.Child_2of2 === GetConstant (2))
          res (t.Child_1of2, "arith_mul2r")
        else
          log (WARN, $"failed to normalize mult: $t")
          (t, null, 0, E.App ("arith_var_norm", [E.Term (t, false)]))
      | "-" when t.Arity == 1 with rule_add = "2" \
      | "~" with rule_add = "" =>
        def (t1, s1, c1, e1) = Simplify (t.OnlyChild)
        (neg (t1), neg (s1), -c1, E.App ("arith_neg_norm" + rule_add, [e1]))
      | "-" | "+" =>
        def (t1, s1, c1, e1) = Simplify (t.Child_1of2)
        def (t2, s2, c2, e2) = Simplify (t.Child_2of2)
        if (t1 != null && s1 == null && t2 != null && s2 == null)
          // one variable on each side
          if (t.Name == "-")
            def ver = if (t2.Name == "~") "3" else "2"
            (t1, neg (t2), c1 - c2, E.App ($"arith_minus_norm$ver", [e1, e2]))
          else
            (t1, t2, c1 + c2, E.App ("arith_plus_norm2", [e1, e2]))
        else if (t1 != null && t2 != null)
          // three or more variables: not UTVPI
          broken = true
          (null, null, 0, null)
        else if (t2 == null)
          assert (s2 == null)
          if (t.Name == "-")
            (t1, s1, c1 - c2, E.App ("arith_minus_norm", [e1, e2]))
          else
            (t1, s1, c1 + c2, E.App ("arith_plus_norm", [e1, e2]))
        else
          assert (t1 == null && s1 == null)
          if (t.Name == "-")
            (neg (t2), neg (s2), c1 - c2, E.App ("arith_minus_norm", [e1, e2]))
          else
            (t2, s2, c1 + c2, E.App ("arith_plus_norm", [e1, e2]))
      | n when t.Arity == 0 && LinearTheory.IsNumber (n) =>
        (null, null, int.Parse (n), E.App ("arith_const_norm", [E.Term (t, false)]))
      | _ =>
        (t, null, 0, E.App ("arith_var_norm", [E.Term (t, false)]))

  // Reduces an (optionally negated) comparison  t1 head t2  to one or two
  // "<=" constraints (strict inequalities become "+ 1 <="), then normalises
  // t1 - t2 with Simplify and saves the resulting UTVPI constraint.
  DoMakeConstraints (neg : bool, head : string, t1 : Term, t2 : Term, proof : E) : void
    match (head, neg)
      | ("=", false) =>
        DoMakeConstraints (false, "<=", t1, t2, E.App ("arith_eq", [proof]))
        DoMakeConstraints (false, "<=", t2, t1, E.App ("arith_eq", [E.App ("eq_symm", [proof])]))
      | ("<", true) =>
        DoMakeConstraints (false, "<=", t2, t1, E.App ("arith_neg_lt", [proof]))
      | (">", true) =>
        DoMakeConstraints (false, "<=", t1, t2, E.App ("arith_neg_gt", [proof]))
      | (">=", false) =>
        DoMakeConstraints (false, "<=", t2, t1, E.App ("arith_ge", [proof]))
      | ("<", false) =>
        DoMakeConstraints (false, "<=", pool.Get ("+", [t1, one]), t2, E.App ("arith_lt", [proof]))
      | (">", false) =>
        DoMakeConstraints (false, "<=", pool.Get ("+", [t2, one]), t1, E.App ("arith_gt", [proof]))
      | (">=", true) =>
        DoMakeConstraints (false, "<=", pool.Get ("+", [t1, one]), t2, E.App ("arith_neg_ge", [proof]))
      | ("<=", true) =>
        DoMakeConstraints (false, "<=", pool.Get ("+", [t2, one]), t1, E.App ("arith_neg_le", [proof]))
      | ("<=", false) =>
        // clear any flag left behind by an earlier Simplify (e.g. via the
        // public SimplifyTerm) so a valid constraint is not dropped
        broken = false
        def (t, s, c, simpl_proof) = Simplify (pool.Get ("-", [t1, t2]))
        if (broken)
          log (WARN, $"failed to normalize: $t1 - $t2")
          broken = false
        else
          // t1 - t2 = t + s + c <= 0, i.e. t + s <= -c
          SaveConstraint (Constraint (t, s, -c, E.App ("arith_leq_norm", [simpl_proof, proof])))
      | _ => assert (false)

  // Records a constraint, unless a contradiction was already found:
  //  * no terms: 0 <= bound, a contradiction when bound < 0;
  //  * the same term twice with opposite signs: likewise a contradiction
  //    when bound < 0;
  //  * the same term twice with the same sign: tightened to a single-term
  //    constraint with floor (bound / 2);
  //  * otherwise stored and indexed, unless an equal or stronger constraint
  //    with the same left-hand side is already stored.  That subsumption
  //    check is what keeps Infer from deriving ever weaker bounds forever
  //    around positive cycles.
  SaveConstraint (c : Constraint) : void
    log (APROOF, $ "saving: $c")
    assert (c.term1 != null || !c.neg1)
    assert (c.term2 != null || !c.neg2)
    when (!stop)
      if (c.term1 == null && c.term2 == null)
        when (c.bound < 0)
          stop = true
          match (c.proof)
            | E.App ("utvpi_contr1", _) \
            | E.App ("utvpi_contr2", _) =>
              contradiction = c.proof
            | _ =>
              contradiction = E.App ("utvpi_contr_const", [c.proof])
      else if (c.term1 === c.term2)
        if (c.neg1 != c.neg2)
          when (c.bound < 0)
            stop = true
            def swaped = E.App ("utvpi_swap", [c.proof])
            def (c1, c2) = if (c.neg1) (swaped, c.proof) else (c.proof, swaped)
            contradiction = E.App ("utvpi_contr2", [c1, c2])
        else
          SaveConstraint (Constraint (c.neg1, c.term1, false, null, Div2 (c.bound), E.App ("utvpi_tight2", [c.proof])))
      else
        def key = SignatureOf (c)
        mutable best
        if (best_bounds.TryGetValue (key, out best) && best <= c.bound)
          log (APROOF, $ "subsumed: $c")
        else
          best_bounds [key] = c.bound
          mutable lst
          when (c.term1 != null)
            unless (constraints_by_term.TryGetValue (c.term1, out lst))
              lst = []
            lst ::= c
            constraints_by_term [c.term1] = lst
          when (c.term2 != null)
            unless (constraints_by_term.TryGetValue (c.term2, out lst))
              lst = []
            lst ::= c
            constraints_by_term [c.term2] = lst
          all_constraints ::= c

  // Translates an arithmetic literal into constraints.  A comparison atom
  // (term2 == null) with head <=, <, >, >= is handled with its polarity; an
  // equality atom (term2 != null) contributes only when positive, since a
  // disequality yields no UTVPI constraint.
  MakeConstraints (l : Literal) : void
    if (l.atom.term2 == null)
      def t = l.atom.term1
      match (t.Name)
        | "<=" | "<" | ">" | ">=" =>
          DoMakeConstraints (l.IsNeg, t.Name, t.Child_1of2, t.Child_2of2, get_lit (l))
        | _ => {}
    else
      unless (l.IsNeg)
        DoMakeConstraints (false, "=", l.atom.term1, l.atom.term2, get_lit (l))

  // get_lit: maps a literal to the proof term for it (used as the leaf proof
  // of every constraint built from that literal).
  public this (get_lit : Literal -> E, core : Core)
    this.get_lit = get_lit
    pool = core.TermPool
    zero = pool.Get ("0", [])
    one = pool.Get ("1", [])

  // Builds constraints from lits, saturates them, and returns the proof of a
  // contradiction, or null if none was derived.
  public TryFindContradiction (lits : list [Literal]) : E
    lits.Iter (MakeConstraints)
    Infer ()
    //log (TEMP, $ "found arith proof: $(contradiction)")
    contradiction