diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IApplTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IApplTerm.java index d5e70349c..d439f57a6 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IApplTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IApplTerm.java @@ -12,4 +12,7 @@ public interface IApplTerm extends ITerm { @Override IApplTerm withAttachments(IAttachments value); + @Override default Tag termTag() { + return Tag.IApplTerm; + } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IBlobTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IBlobTerm.java index 654887244..d96dfadf2 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IBlobTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IBlobTerm.java @@ -7,4 +7,8 @@ public interface IBlobTerm extends ITerm { @Override IBlobTerm withAttachments(IAttachments value); + @Override default Tag termTag() { + return Tag.IBlobTerm; + } + } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IConsTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IConsTerm.java index 4555448ad..017d9c87f 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IConsTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IConsTerm.java @@ -9,4 +9,11 @@ public interface IConsTerm extends IListTerm { @Override IConsTerm withAttachments(IAttachments value); + @Override default ITerm.Tag termTag() { + return ITerm.Tag.IConsTerm; + } + + @Override default IListTerm.Tag listTermTag() { + return IListTerm.Tag.IConsTerm; + } } \ No newline at end of file diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IIntTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IIntTerm.java index 57ecd5af9..851d99edf 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IIntTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IIntTerm.java @@ -7,4 +7,8 @@ public interface IIntTerm extends ITerm { @Override IIntTerm withAttachments(IAttachments value); + @Override default Tag termTag() { + return Tag.IIntTerm; + } + } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IListTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IListTerm.java index cb3f296f1..fc0b4dc20 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IListTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IListTerm.java @@ -45,4 +45,12 @@ default T apply(IListTerm list) throws E { @Override IListTerm withAttachments(IAttachments value); + Tag listTermTag(); + + enum Tag { + IConsTerm, + INilTerm, + ITermVar + } + } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/INilTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/INilTerm.java index d6db08866..6f62ac946 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/INilTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/INilTerm.java @@ -5,4 +5,12 @@ public interface INilTerm extends IListTerm { @Override INilTerm withAttachments(IAttachments value); + @Override default ITerm.Tag termTag() { + return ITerm.Tag.INilTerm; + } + + @Override default IListTerm.Tag listTermTag() { + return IListTerm.Tag.INilTerm; + } + } \ No newline at end of file diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/IStringTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/IStringTerm.java index 2124ee155..b31a65fec 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/IStringTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/IStringTerm.java @@ -7,4 +7,8 @@ public interface IStringTerm extends ITerm { @Override IStringTerm withAttachments(IAttachments value); + @Override default Tag termTag() { + return Tag.IStringTerm; + } + } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/ITerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/ITerm.java index 8484e4d03..d88e1bd32 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/ITerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/ITerm.java @@ -30,13 +30,13 @@ interface Cases { T caseAppl(IApplTerm appl); - T caseList(IListTerm cons); + T caseList(IListTerm list); T caseString(IStringTerm string); T caseInt(IIntTerm integer); - T caseBlob(IBlobTerm integer); + T caseBlob(IBlobTerm blob); T caseVar(ITermVar var); @@ -48,13 +48,13 @@ interface CheckedCases { T caseAppl(IApplTerm appl) throws E; - T caseList(IListTerm cons) throws E; + T caseList(IListTerm list) throws E; T caseString(IStringTerm string) throws E; T caseInt(IIntTerm integer) throws E; - T caseBlob(IBlobTerm integer) throws E; + T caseBlob(IBlobTerm blob) throws E; T caseVar(ITermVar var) throws E; @@ -64,4 +64,16 @@ default T caseLock(ITerm term) throws E { } + Tag termTag(); + + enum Tag { + IApplTerm, + IConsTerm, + INilTerm, + IStringTerm, + IIntTerm, + IBlobTerm, + ITermVar, + } + } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/ITermVar.java b/nabl2.terms/src/main/java/mb/nabl2/terms/ITermVar.java index c9896f8ca..8d69b8d31 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/ITermVar.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/ITermVar.java @@ -19,4 +19,12 @@ public interface ITermVar extends ITerm, IListTerm, Comparable { return c; } + @Override default ITerm.Tag termTag() { + return ITerm.Tag.ITermVar; + } + + @Override default IListTerm.Tag listTermTag() { + return IListTerm.Tag.ITermVar; + } + } \ No newline at end of file diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/ListTerms.java b/nabl2.terms/src/main/java/mb/nabl2/terms/ListTerms.java index b39762310..e53f185e3 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/ListTerms.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/ListTerms.java @@ -174,8 +174,20 @@ public IListTerm.CheckedCases otherwise(final CheckedFunction1 toStringTail(cons.getHead(), cons.getTail()), nil -> "[]", var -> var.toString())); + switch(list.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) list; + return toStringTail(cons.getHead(), cons.getTail()); + } + case INilTerm: { + return "[]"; + } + case ITermVar: { + return list.toString(); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } private static String toStringTail(ITerm head, IListTerm tail) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/build/AConsTerm.java b/nabl2.terms/src/main/java/mb/nabl2/terms/build/AConsTerm.java index ab25d39be..6623d69fa 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/build/AConsTerm.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/build/AConsTerm.java @@ -8,13 +8,14 @@ import org.immutables.value.Value; import org.metaborg.util.collection.CapsuleUtil; import org.metaborg.util.functions.Action1; +import org.metaborg.util.functions.Function2; import io.usethesource.capsule.Set; import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; @Value.Immutable(lazyhash = false) @Serial.Version(value = 42L) @@ -94,23 +95,30 @@ abstract class AConsTerm extends AbstractTerm implements IConsTerm { StringBuilder sb = new StringBuilder(); sb.append("["); sb.append(getHead()); - getTail().match(ListTerms.casesFix( - // @formatter:off - (f,cons) -> { + toString(getTail(), sb); + sb.append("]"); + return sb.toString(); + } + + private static void toString(IListTerm subj, StringBuilder sb) { + switch(subj.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) subj; sb.append(","); sb.append(cons.getHead()); - return cons.getTail().match(f); - }, - (f,nil) -> unit, - (f,var) -> { + toString(cons.getTail(), sb); + break; + } + + case INilTerm: { + break; + } + + case ITermVar: { ITermVar var = (ITermVar) subj; sb.append("|"); sb.append(var); - return unit; + break; } - // @formatter:on - )); - sb.append("]"); - return sb.toString(); + } } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/build/ListTermIterator.java b/nabl2.terms/src/main/java/mb/nabl2/terms/build/ListTermIterator.java index dbeacd6a9..5b8f23658 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/build/ListTermIterator.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/build/ListTermIterator.java @@ -3,9 +3,13 @@ import java.util.Iterator; import java.util.NoSuchElementException; +import org.metaborg.util.functions.Function1; + +import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; import mb.nabl2.terms.ITerm; -import mb.nabl2.terms.ListTerms; +import mb.nabl2.terms.ITermVar; public class ListTermIterator implements Iterator { @@ -16,20 +20,42 @@ public ListTermIterator(IListTerm list) { } @Override public boolean hasNext() { - return current.match(ListTerms.cases(cons -> true, nil -> false, var -> { - throw new IllegalStateException("Cannot iterate over a non-ground list."); - })); + IListTerm subj = current; + switch(subj.listTermTag()) { + case IConsTerm: { + return true; + } + + case INilTerm: { + return false; + } + + case ITermVar: { + throw new IllegalStateException("Cannot iterate over a non-ground list."); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } @Override public ITerm next() { - return current.match(ListTerms.cases(cons -> { - current = cons.getTail(); - return cons.getHead(); - }, nil -> { - throw new NoSuchElementException(); - }, var -> { - throw new IllegalStateException("Cannot iterate over a non-ground list."); - })); + IListTerm subj = current; + switch(subj.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) subj; + current = cons.getTail(); + return cons.getHead(); + } + + case INilTerm: { + throw new NoSuchElementException(); + } + + case ITermVar: { + throw new IllegalStateException("Cannot iterate over a non-ground list."); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } } \ No newline at end of file diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ApplPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ApplPattern.java index 68a13d65a..28ef3ba58 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ApplPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ApplPattern.java @@ -15,7 +15,12 @@ import com.google.common.collect.ImmutableList; import io.usethesource.capsule.Set; +import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IIntTerm; +import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.Terms; @@ -57,22 +62,32 @@ public List getArgs() { @Override protected boolean matchTerm(ITerm term, ISubstitution.Transient subst, IUnifier.Immutable unifier, Eqs eqs) { - // @formatter:off - return unifier.findTerm(term).match(Terms.cases() - .appl(applTerm -> { - if(applTerm.getArity() == this.args.size() && applTerm.getOp().equals(op)) { + final ApplPattern pattern = this; + ITerm subj = unifier.findTerm(term); + switch(subj.termTag()) { + case IApplTerm: { IApplTerm applTerm = (IApplTerm) subj; + if(applTerm.getArity() == pattern.args.size() && applTerm.getOp().equals(op)) { return matchTerms(args, applTerm.getArgs(), subst, unifier, eqs); } else { return false; } - }).var(v -> { - eqs.add(v, this); + } + + case ITermVar: { ITermVar v = (ITermVar) subj; + eqs.add(v, pattern); return true; - }).otherwise(t -> { + } + + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { return false; - }) - ); - // @formatter:on + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } @Override public Pattern apply(IRenaming subst) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/CheckedTermMatch.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/CheckedTermMatch.java index 9c493cbb8..e1c41befb 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/CheckedTermMatch.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/CheckedTermMatch.java @@ -1,5 +1,6 @@ package mb.nabl2.terms.matching; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -9,9 +10,8 @@ import org.metaborg.util.functions.CheckedFunction3; import org.metaborg.util.functions.CheckedFunction4; -import com.google.common.collect.Lists; - import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IBlobTerm; import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; @@ -20,7 +20,6 @@ import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.ListTerms; -import mb.nabl2.terms.Terms; import mb.nabl2.terms.unification.Unifiers; import mb.nabl2.terms.unification.u.IUnifier; @@ -33,35 +32,48 @@ public static class CM { // term public ICheckedMatcher term() { - return (term, unifier) -> Optional.of(term); + return (term, unifier) -> { + return Optional.of(term); + }; } public ICheckedMatcher term(CheckedFunction1 f) { - return (term, unifier) -> Optional.of(f.apply(term)); + return (term, unifier) -> { + return Optional.of(f.apply(term)); + }; } public ICheckedMatcher term(ITerm.CheckedCases, E> cases) { - return (term, unifier) -> unifier.findTerm(term).matchOrThrow(cases); + return (term, unifier) -> { + return unifier.findTerm(term).matchOrThrow(cases); + }; } // appl public ICheckedMatcher appl(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(appl -> Optional.of(f.apply(appl)), this::empty, - this::empty, this::empty, this::empty, this::empty)); + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + return Optional.of(f.apply((IApplTerm) subj)); + } + return Optional.empty(); + }; } public ICheckedMatcher appl0(String op, CheckedFunction1 f) { return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow(Terms., E>checkedCases(appl -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 0 && op.equals(appl.getOp()))) { return Optional.empty(); } return Optional.of(f.apply(appl)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } + return Optional.empty(); }; } @@ -69,7 +81,9 @@ public ICheckedMatcher appl1(String op, ICheckedMatcher m, CheckedFunction2 f) { return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow(Terms., E>checkedCases(appl -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 1 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -79,7 +93,8 @@ public ICheckedMatcher appl1(String op, } T t = o1.get(); return Optional.of(f.apply(appl, t)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } + return Optional.empty(); }; } @@ -87,7 +102,9 @@ public ICheckedMatcher appl2(String op, ICheckedMatcher m1, ICheckedMatcher m2, CheckedFunction3 f) { return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow(Terms., E>checkedCases(appl -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 2 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -102,7 +119,8 @@ public ICheckedMatcher appl2(String op, } T2 t2 = o2.get(); return Optional.of(f.apply(appl, t1, t2)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } + return Optional.empty(); }; } @@ -111,7 +129,9 @@ public ICheckedMatcher appl3(String o ICheckedMatcher m3, CheckedFunction4 f) { return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow(Terms., E>checkedCases(appl -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 3 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -131,7 +151,8 @@ public ICheckedMatcher appl3(String o } T3 t3 = o3.get(); return Optional.of(f.apply(appl, t1, t2, t3)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } + return Optional.empty(); }; } @@ -139,26 +160,34 @@ public ICheckedMatcher appl3(String o public ICheckedMatcher list(CheckedFunction1 f) { - final CheckedFunction1, E> g = list -> Optional.of(f.apply(list)); return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow( - Terms., E>checkedCases(this::empty, g, this::empty, this::empty, this::empty, g)); + final ITerm subj = unifier.findTerm(term); + if(subj instanceof IListTerm) { + IListTerm list = (IListTerm) subj; + return Optional.of(f.apply(list)); + } + return Optional.empty(); }; } public ICheckedMatcher list(IListTerm.CheckedCases, E> cases) { - final CheckedFunction1, E> g = list -> list.matchOrThrow(cases); return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow( - Terms., E>checkedCases(this::empty, g, this::empty, this::empty, this::empty, g)); + final ITerm subj = unifier.findTerm(term); + if(subj instanceof IListTerm) { + IListTerm list = (IListTerm) subj; + return list.matchOrThrow(cases); + } + return Optional.empty(); }; } public ICheckedMatcher listElems(ICheckedMatcher m, CheckedFunction2, R, ? extends E> f) { return (term, unifier) -> { - return unifier.findTerm(term).matchOrThrow(Terms., E>checkedCases(this::empty, list -> { - List ts = Lists.newArrayList(); + final ITerm subj = unifier.findTerm(term); + if(subj instanceof IListTerm && subj.termTag() != ITerm.Tag.ITermVar) { + IListTerm list = (IListTerm) subj; + List ts = new ArrayList<>(); for(ITerm t : ListTerms.iterable(list)) { Optional o = m.matchOrThrow(t, unifier); if(!o.isPresent()) { @@ -167,55 +196,74 @@ public ICheckedMatcher listElems(ICheckedMatch ts.add(o.get()); } return Optional.of(f.apply(list, ts)); - }, this::empty, this::empty, this::empty, this::empty)); + } + return Optional.empty(); }; } public ICheckedMatcher cons(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(this::empty, list -> { - return list.matchOrThrow(ListTerms., E>checkedCases( - cons -> Optional.of(f.apply(cons)), nil -> Optional.empty(), var -> Optional.empty())); - }, this::empty, this::empty, this::empty, this::empty)); - + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IConsTerm) { + IConsTerm cons = (IConsTerm) subj; + return Optional.of(f.apply(cons)); + } + return Optional.empty(); + }; } public ICheckedMatcher nil(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(this::empty, list -> { - return list.matchOrThrow(ListTerms., E>checkedCases(cons -> Optional.empty(), - nil -> Optional.of(f.apply(nil)), var -> Optional.empty())); - }, this::empty, this::empty, this::empty, this::empty)); + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.INilTerm) { + INilTerm nil = (INilTerm) subj; + return Optional.of(f.apply(nil)); + } + return Optional.empty(); + }; } // integer public ICheckedMatcher integer(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(this::empty, this::empty, this::empty, - string -> Optional.of(f.apply(string)), this::empty, this::empty)); + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IIntTerm) { + IIntTerm integer = (IIntTerm) subj; + return Optional.of(f.apply(integer)); + } + return Optional.empty(); + }; } // string public ICheckedMatcher string(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(this::empty, this::empty, - string -> Optional.of(f.apply(string)), this::empty, this::empty, this::empty)); + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IStringTerm) { + IStringTerm string = (IStringTerm) subj; + return Optional.of(f.apply(string)); + } + return Optional.empty(); + }; } // var public ICheckedMatcher var(CheckedFunction1 f) { - return (term, unifier) -> unifier.findTerm(term) - .matchOrThrow(Terms., E>checkedCases(this::empty, list -> { - return list.matchOrThrow(ListTerms., E>checkedCases(cons -> Optional.empty(), - nil -> Optional.empty(), var -> Optional.of(f.apply(var)))); - }, this::empty, this::empty, this::empty, var -> Optional.of(f.apply(var)))); + return (term, unifier) -> { + final ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.ITermVar) { + ITermVar var = (ITermVar) subj; + return Optional.of(f.apply(var)); + } + return Optional.empty(); + }; } // cases @@ -235,10 +283,6 @@ public ICheckedMatcher string(CheckedFunction1 Optional empty(@SuppressWarnings("unused") ITerm term) { - return Optional.empty(); - } - } @FunctionalInterface @@ -251,11 +295,15 @@ default Optional matchOrThrow(ITerm term) throws E { } default ICheckedMatcher map(Function fun) { - return (term, unifier) -> this.matchOrThrow(term, unifier).map(fun); + return (term, unifier) -> { + return this.matchOrThrow(term, unifier).map(fun); + }; } default ICheckedMatcher flatMap(Function> fun) { - return (term, unifier) -> this.matchOrThrow(term, unifier).flatMap(fun); + return (term, unifier) -> { + return this.matchOrThrow(term, unifier).flatMap(fun); + }; } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ConsPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ConsPattern.java index f19e56b6d..892a3e7fc 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ConsPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/ConsPattern.java @@ -1,8 +1,5 @@ package mb.nabl2.terms.matching; -import static mb.nabl2.terms.build.TermBuild.B; -import static mb.nabl2.terms.matching.TermMatch.M; - import java.util.Objects; import java.util.Optional; @@ -14,13 +11,17 @@ import io.usethesource.capsule.Set; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.substitution.IRenaming; import mb.nabl2.terms.substitution.ISubstitution; +import mb.nabl2.terms.unification.Unifiers; import mb.nabl2.terms.unification.u.IUnifier; +import mb.nabl2.terms.unification.ud.IUniDisunifier; + +import static mb.nabl2.terms.build.TermBuild.B; class ConsPattern extends Pattern { private static final long serialVersionUID = 1L; @@ -55,21 +56,31 @@ public Pattern getTail() { @Override protected boolean matchTerm(ITerm term, ISubstitution.Transient subst, IUnifier.Immutable unifier, Eqs eqs) { - // @formatter:off - return M.list(listTerm -> { - return listTerm.match(ListTerms.cases() - .cons(consTerm -> { + ITerm subj = Unifiers.Immutable.of().findTerm(unifier.findTerm(term)); + if(subj instanceof IListTerm) { + final IListTerm list = (IListTerm) subj; + switch(list.listTermTag()) { + case IConsTerm: { + IConsTerm consTerm = (IConsTerm) list; return matchTerms(Iterables2.from(head, tail), - Iterables2.from(consTerm.getHead(), consTerm.getTail()), subst, unifier, eqs); - }).var(v -> { - eqs.add(v, this); - return true; - }).otherwise(t -> { + Iterables2.from(consTerm.getHead(), consTerm.getTail()), subst, unifier, + eqs); + } + + case INilTerm: { return false; - }) - ); - }).match(unifier.findTerm(term)).orElse(false); - // @formatter:on + } + + case ITermVar: { + eqs.add((ITermVar) list, this); + return true; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); + } else { + return false; + } } @Override public ConsPattern apply(IRenaming subst) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/IntPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/IntPattern.java index 538f0ccdd..c1297b40a 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/IntPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/IntPattern.java @@ -11,7 +11,12 @@ import org.metaborg.util.functions.Function1; import io.usethesource.capsule.Set; +import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IIntTerm; +import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.Terms; @@ -43,18 +48,28 @@ public int getValue() { @Override protected boolean matchTerm(ITerm term, ISubstitution.Transient subst, IUnifier.Immutable unifier, Eqs eqs) { - // @formatter:off - return unifier.findTerm(term).match(Terms.cases() - .integer(intTerm -> { + final IntPattern pattern = this; + ITerm subj = unifier.findTerm(term); + switch(subj.termTag()) { + case IIntTerm: { IIntTerm intTerm = (IIntTerm) subj; return intTerm.getValue() == value; - }).var(v -> { - eqs.add(v, this); + } + + case ITermVar: { ITermVar v = (ITermVar) subj; + eqs.add(v, pattern); return true; - }).otherwise(t -> { + } + + case IApplTerm: + case IConsTerm: + case INilTerm: + case IStringTerm: + case IBlobTerm: { return false; - }) - ); - // @formatter:on + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } @Override public IntPattern apply(IRenaming subst) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/NilPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/NilPattern.java index 74621aea8..5f5ca42b5 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/NilPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/NilPattern.java @@ -1,24 +1,29 @@ package mb.nabl2.terms.matching; -import static mb.nabl2.terms.build.TermBuild.B; -import static mb.nabl2.terms.matching.TermMatch.M; - import java.util.Optional; import org.metaborg.util.collection.CapsuleUtil; import org.metaborg.util.functions.Action2; import org.metaborg.util.functions.Function0; import org.metaborg.util.functions.Function1; +import org.metaborg.util.iterators.Iterables2; import io.usethesource.capsule.Set; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.substitution.IRenaming; import mb.nabl2.terms.substitution.ISubstitution; +import mb.nabl2.terms.unification.Unifiers; import mb.nabl2.terms.unification.u.IUnifier; +import static mb.nabl2.terms.build.TermBuild.B; +import static mb.nabl2.terms.matching.TermMatch.M; + class NilPattern extends Pattern { private static final long serialVersionUID = 1L; @@ -36,20 +41,28 @@ public NilPattern(IAttachments attachments) { @Override protected boolean matchTerm(ITerm term, ISubstitution.Transient subst, IUnifier.Immutable unifier, Eqs eqs) { - // @formatter:off - return M.list(listTerm -> { - return listTerm.match(ListTerms.cases() - .nil(nilTerm -> { + ITerm subj = Unifiers.Immutable.of().findTerm(unifier.findTerm(term)); + if(subj instanceof IListTerm) { + final IListTerm list = (IListTerm) subj; + switch(list.listTermTag()) { + case IConsTerm: { + return false; + } + + case INilTerm: { return true; - }).var(v -> { - eqs.add(v, this); + } + + case ITermVar: { + eqs.add((ITermVar) list, this); return true; - }).otherwise(t -> { - return false; - }) - ); - }).match(unifier.findTerm(term)).orElse(false); - // @formatter:on + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); + } else { + return false; + } } @Override public NilPattern apply(IRenaming subst) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Pattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Pattern.java index d60204c2d..33f55ad9a 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Pattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Pattern.java @@ -1,6 +1,7 @@ package mb.nabl2.terms.matching; import java.io.Serializable; +import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -53,7 +54,7 @@ public Optional match(ITerm term) { public MaybeNotInstantiated> match(ITerm term, IUnifier.Immutable unifier) { final ISubstitution.Transient subst = PersistentSubstitution.Transient.of(); - final List stuckVars = Lists.newArrayList(); + final List stuckVars = new ArrayList<>(); final Eqs eqs = new Eqs() { @Override public void add(ITermVar var, ITerm pattern) { @@ -84,9 +85,9 @@ public Optional matchWithEqs(ITerm term, IUnifier.Immutable unifier // substitution from pattern variables to unifier variables final ISubstitution.Transient _subst = PersistentSubstitution.Transient.of(); // equalities between unifier terms - final List> termEqs = Lists.newArrayList(); + final List> termEqs = new ArrayList<>(); // equalities between unifier variables and patterns - final List> patternEqs = Lists.newArrayList(); + final List> patternEqs = new ArrayList<>(); // match final Eqs eqs = new Eqs() { @@ -122,9 +123,7 @@ public Optional matchWithEqs(ITerm term, IUnifier.Immutable unifier } for(Tuple2 patternEq : patternEqs) { final ITermVar leftVar = patternEq._1(); - final ITerm rightTerm = patternEq._2().asTerm((v, t) -> { - allEqs.add(Tuple2.of(subst.apply(v), subst.apply(t))); - }, (v) -> v.orElseGet(() -> fresh.freshWld())); + final ITerm rightTerm = patternEq._2().asTerm((v, t) -> allEqs.add(Tuple2.of(subst.apply(v), subst.apply(t))), (v) -> v.orElseGet(() -> fresh.freshWld())); stuckVars.add(leftVar); allEqs.add(Tuple2.of(leftVar, subst.apply(rightTerm))); } @@ -159,9 +158,7 @@ protected static boolean matchTerms(final Iterable patterns, final Iter public Tuple2>> asTerm(Function1, ITermVar> fresh) { final ImmutableList.Builder> eqs = ImmutableList.builder(); - final ITerm term = asTerm((v, t) -> { - eqs.add(Tuple2.of(v, t)); - }, fresh); + final ITerm term = asTerm((v, t) -> eqs.add(Tuple2.of(v, t)), fresh); return Tuple2.of(term, eqs.build()); } @@ -207,16 +204,16 @@ public Optional compare(Pattern p1, Pattern p2) { private @Nullable Integer compare(Pattern p1, Pattern p2, AtomicInteger pos, Map vars1, Map vars2) { if(p1 instanceof ApplPattern) { - final ApplPattern appl1 = (ApplPattern) p1; + final ApplPattern appl = (ApplPattern) p1; if(p2 instanceof ApplPattern) { final ApplPattern appl2 = (ApplPattern) p2; - if(!appl1.getOp().equals(appl2.getOp())) { + if(!appl.getOp().equals(appl2.getOp())) { return null; } - if(appl1.getArgs().size() != appl2.getArgs().size()) { + if(appl.getArgs().size() != appl2.getArgs().size()) { return null; } - final Iterator it1 = appl1.getArgs().iterator(); + final Iterator it1 = appl.getArgs().iterator(); final Iterator it2 = appl2.getArgs().iterator(); Integer c = 0; while(c != null && c == 0 && it1.hasNext()) { @@ -239,13 +236,13 @@ public Optional compare(Pattern p1, Pattern p2) { return null; } } else if(p1 instanceof ConsPattern) { - final ConsPattern cons1 = (ConsPattern) p1; + final ConsPattern cons = (ConsPattern) p1; if(p2 instanceof ConsPattern) { final ConsPattern cons2 = (ConsPattern) p2; Integer c = 0; - c = compare(cons1.getHead(), cons2.getHead(), pos, vars1, vars2); + c = compare(cons.getHead(), cons2.getHead(), pos, vars1, vars2); if(c != null && c == 0) { - c = compare(cons1.getTail(), cons2.getTail(), pos, vars1, vars2); + c = compare(cons.getTail(), cons2.getTail(), pos, vars1, vars2); } return c; } else if(p2 instanceof PatternVar) { @@ -282,10 +279,10 @@ public Optional compare(Pattern p1, Pattern p2) { return null; } } else if(p1 instanceof StringPattern) { - final StringPattern string1 = (StringPattern) p1; + final StringPattern string = (StringPattern) p1; if(p2 instanceof StringPattern) { final StringPattern string2 = (StringPattern) p2; - return string1.getValue().equals(string2.getValue()) ? 0 : null; + return string.getValue().equals(string2.getValue()) ? 0 : null; } else if(p2 instanceof PatternVar) { final PatternVar var2 = (PatternVar) p2; if(boundAt(var2, vars2) >= 0) { @@ -302,10 +299,10 @@ public Optional compare(Pattern p1, Pattern p2) { return null; } } else if(p1 instanceof IntPattern) { - final IntPattern integer1 = (IntPattern) p1; + final IntPattern integer = (IntPattern) p1; if(p2 instanceof IntPattern) { final IntPattern integer2 = (IntPattern) p2; - return integer1.getValue() == integer2.getValue() ? 0 : null; + return integer.getValue() == integer2.getValue() ? 0 : null; } else if(p2 instanceof PatternVar) { final PatternVar var2 = (PatternVar) p2; if(boundAt(var2, vars2) >= 0) { @@ -322,13 +319,13 @@ public Optional compare(Pattern p1, Pattern p2) { return null; } } else if(p1 instanceof PatternVar) { - final PatternVar var1 = (PatternVar) p1; - final int i1 = boundAt(var1, vars1); + final PatternVar var = (PatternVar) p1; + final int i1 = boundAt(var, vars1); if(p2 instanceof PatternVar) { final PatternVar var2 = (PatternVar) p2; final int i2 = boundAt(var2, vars2); if(i1 < 0 && i2 < 0) { // neither are bound - bind(var1.getVar(), vars1, var2.getVar(), vars2, pos.getAndIncrement()); + bind(var.getVar(), vars1, var2.getVar(), vars2, pos.getAndIncrement()); return 0; } else if(i1 < 0 && i2 >= 0) { // p2 is bound bind(var2.getVar(), vars1, pos.getAndIncrement()); @@ -347,9 +344,9 @@ public Optional compare(Pattern p1, Pattern p2) { return 1; } } else if(p1 instanceof PatternAs) { - final PatternAs as1 = (PatternAs) p1; - bind(as1.getVar(), vars1, pos.get()); // FIXME what if this is already bound? - return compare(as1.getPattern(), p2, pos, vars1, vars2); + final PatternAs as = (PatternAs) p1; + bind(as.getVar(), vars1, pos.get()); // FIXME what if this is already bound? + return compare(as.getPattern(), p2, pos, vars1, vars2); } else { return null; } @@ -385,11 +382,7 @@ private void bind(@Nullable ITermVar v, Map vars, int pos) { * Note: this comparator imposes orderings that are inconsistent with equals. */ public java.util.Comparator asComparator() { - return new java.util.Comparator() { - @Override public int compare(Pattern p1, Pattern p2) { - return LeftRightOrder.this.compare(p1, p2).orElse(0); - } - }; + return (p1, p2) -> LeftRightOrder.this.compare(p1, p2).orElse(0); } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/StringPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/StringPattern.java index 84eb041f5..40ef39baf 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/StringPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/StringPattern.java @@ -11,7 +11,12 @@ import org.metaborg.util.functions.Function1; import io.usethesource.capsule.Set; +import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IIntTerm; +import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.Terms; @@ -43,18 +48,28 @@ public String getValue() { @Override protected boolean matchTerm(ITerm term, ISubstitution.Transient subst, IUnifier.Immutable unifier, Eqs eqs) { - // @formatter:off - return unifier.findTerm(term).match(Terms.cases() - .string(stringTerm -> { + final StringPattern pattern = this; + ITerm subj = unifier.findTerm(term); + switch(subj.termTag()) { + case IStringTerm: { IStringTerm stringTerm = (IStringTerm) subj; return stringTerm.getValue().equals(value); - }).var(v -> { - eqs.add(v, this); + } + + case ITermVar: { ITermVar v = (ITermVar) subj; + eqs.add(v, pattern); return true; - }).otherwise(t -> { + } + + case IApplTerm: + case IConsTerm: + case INilTerm: + case IIntTerm: + case IBlobTerm: { return false; - }) - ); - // @formatter:on + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } @Override public StringPattern apply(IRenaming subst) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermMatch.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermMatch.java index 51631b25a..fa6795c29 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermMatch.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermMatch.java @@ -2,6 +2,7 @@ import static mb.nabl2.terms.Terms.TUPLE_OP; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -32,7 +33,6 @@ import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.ListTerms; -import mb.nabl2.terms.Terms; import mb.nabl2.terms.unification.Unifiers; import mb.nabl2.terms.unification.u.IUnifier; @@ -63,8 +63,15 @@ public IMatcher term(ITerm.Cases> cases) { // appl public IMatcher appl() { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(Optional::of, - this::empty, this::empty, this::empty, this::empty, this::empty)); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; + return Optional.of(appl); + } else { + return Optional.empty(); + } + }; } public IMatcher appl(String op, Function1 f) { @@ -72,23 +79,33 @@ public IMatcher appl(String op, Function1 f) { } public IMatcher appl(Function1 f) { - return (term, unifier) -> unifier.findTerm(term) - .match(Terms.>cases(appl -> Optional.of(f.apply(appl)), this::empty, this::empty, - this::empty, this::empty, this::empty)); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; + return Optional.of(f.apply(appl)); + } else { + return Optional.empty(); + } + }; } public IMatcher appl0(String op) { - return appl0(op, (appl) -> appl); + return appl0(op, appl -> appl); } public IMatcher appl0(String op, Function1 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 0 && op.equals(appl.getOp()))) { return Optional.empty(); } return Optional.of(f.apply(appl)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } else { + return Optional.empty(); + } }; } @@ -99,12 +116,16 @@ public IMatcher appl1(String op, IMatcher m) { public IMatcher appl1(String op, IMatcher m, Function2 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 1 && op.equals(appl.getOp()))) { return Optional.empty(); } return m.match(appl.getArgs().get(0), unifier).map(t -> f.apply(appl, t)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } else { + return Optional.empty(); + } }; } @@ -115,14 +136,18 @@ public IMatcher appl2(String op, IMatcher m1, public IMatcher appl2(String op, IMatcher m1, IMatcher m2, Function3 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 2 && op.equals(appl.getOp()))) { return Optional.empty(); } Optional o1 = m1.match(appl.getArgs().get(0), unifier); Optional o2 = m2.match(appl.getArgs().get(1), unifier); return Optionals.lift(o1, o2, (t1, t2) -> f.apply(appl, t1, t2)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + } else { + return Optional.empty(); + } }; } @@ -134,15 +159,20 @@ public IMatcher appl3(String op, IMatcher public IMatcher appl3(String op, IMatcher m1, IMatcher m2, IMatcher m3, Function4 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 3 && op.equals(appl.getOp()))) { return Optional.empty(); } Optional o1 = m1.match(appl.getArgs().get(0), unifier); Optional o2 = m2.match(appl.getArgs().get(1), unifier); Optional o3 = m3.match(appl.getArgs().get(2), unifier); - return Optionals.lift(o1, o2, o3, (t1, t2, t3) -> f.apply(appl, t1, t2, t3)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + return Optionals.lift(o1, o2, o3, + (t1, t2, t3) -> f.apply(appl, t1, t2, t3)); + } else { + return Optional.empty(); + } }; } @@ -155,7 +185,9 @@ public IMatcher appl4(String op, IMatcher m IMatcher m3, IMatcher m4, Function5 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 4 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -163,8 +195,11 @@ public IMatcher appl4(String op, IMatcher m Optional o2 = m2.match(appl.getArgs().get(1), unifier); Optional o3 = m3.match(appl.getArgs().get(2), unifier); Optional o4 = m4.match(appl.getArgs().get(3), unifier); - return Optionals.lift(o1, o2, o3, o4, (t1, t2, t3, t4) -> f.apply(appl, t1, t2, t3, t4)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + return Optionals.lift(o1, o2, o3, o4, + (t1, t2, t3, t4) -> f.apply(appl, t1, t2, t3, t4)); + } else { + return Optional.empty(); + } }; } @@ -178,7 +213,9 @@ public IMatcher appl5(String op, IMatcher m5, Function6 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 5 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -188,8 +225,10 @@ public IMatcher appl5(String op, IMatcher o4 = m4.match(appl.getArgs().get(3), unifier); Optional o5 = m5.match(appl.getArgs().get(4), unifier); return Optionals.lift(o1, o2, o3, o4, o5, - (t1, t2, t3, t4, t5) -> f.apply(appl, t1, t2, t3, t4, t5)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + (t1, t2, t3, t4, t5) -> f.apply(appl, t1, t2, t3, t4, t5)); + } else { + return Optional.empty(); + } }; } @@ -203,7 +242,9 @@ public IMatcher appl6(String op, IMatcher m5, IMatcher m6, Function7 f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(appl -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IApplTerm) { + IApplTerm appl = (IApplTerm) subj; if(!(appl.getArity() == 6 && op.equals(appl.getOp()))) { return Optional.empty(); } @@ -214,8 +255,10 @@ public IMatcher appl6(String op, IMatcher o5 = m5.match(appl.getArgs().get(4), unifier); Optional o6 = m6.match(appl.getArgs().get(5), unifier); return Optionals.lift(o1, o2, o3, o4, o5, o6, - (t1, t2, t3, t4, t5, t6) -> f.apply(appl, t1, t2, t3, t4, t5, t6)); - }, this::empty, this::empty, this::empty, this::empty, this::empty)); + (t1, t2, t3, t4, t5, t6) -> f.apply(appl, t1, t2, t3, t4, t5, t6)); + } else { + return Optional.empty(); + } }; } @@ -285,14 +328,18 @@ public IMatcher tuple5(IMatcher m1, IMa // list public IMatcher list() { - return list((l) -> l); + return list(l -> l); } public IMatcher list(Function1 f) { final Function1> g = list -> Optional.of(f.apply(list)); return (term, unifier) -> { - return unifier.findTerm(term) - .match(Terms.>cases(this::empty, g, this::empty, this::empty, this::empty, g)); + ITerm subj = unifier.findTerm(term); + if(subj instanceof IListTerm) { + return g.apply((IListTerm) subj); + } else { + return Optional.empty(); + } }; } @@ -307,34 +354,44 @@ public IMatcher> listElems(IMatcher m) { public IMatcher listElems(IMatcher m, Function2, R> f) { return (term, unifier) -> { - return unifier.findTerm(term).match(Terms.>cases(this::empty, list -> { - List> os = Lists.newArrayList(); + ITerm subj = unifier.findTerm(term); + if(subj instanceof IListTerm && subj.termTag() != ITerm.Tag.ITermVar) { + IListTerm list = (IListTerm) subj; + List> os = new ArrayList<>(); for(ITerm t : ListTerms.iterable(list)) { os.add(m.match(t, unifier)); } - return Optionals.sequence(os).map(ts -> (R) f.apply(list, ImmutableList.copyOf(ts))); - }, this::empty, this::empty, this::empty, this::empty)); + return Optionals.sequence(os) + .map(ts -> f.apply(list, ImmutableList.copyOf(ts))); + } + return Optional.empty(); }; } public IMatcher cons(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, list -> { - return list.match(ListTerms.>cases(cons -> Optional.of(f.apply(cons)), - nil -> Optional.empty(), var -> Optional.empty())); - }, this::empty, this::empty, this::empty, this::empty)); - + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IConsTerm) { + IConsTerm cons = (IConsTerm) subj; + return Optional.of(f.apply(cons)); + } + return Optional.empty(); + }; } public IMatcher cons(IMatcher mhd, IMatcher mtl, Function3 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, list -> { - return list.match(ListTerms.>cases(cons -> { + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IConsTerm) { + IConsTerm cons = (IConsTerm) subj; Optional ohd = mhd.match(cons.getHead(), unifier); Optional otl = mtl.match(cons.getTail(), unifier); - return Optionals.lift(ohd, otl, (thd, ttl) -> f.apply(cons, thd, ttl)); - }, this::empty, this::empty)); - }, this::empty, this::empty, this::empty, this::empty)); - + return Optionals.lift(ohd, otl, + (thd, ttl) -> f.apply(cons, thd, ttl)); + } + return Optional.empty(); + }; } public IMatcher nil() { @@ -342,11 +399,14 @@ public IMatcher nil() { } public IMatcher nil(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, list -> { - return list.match(ListTerms.>cases(cons -> Optional.empty(), - nil -> Optional.of(f.apply(nil)), var -> Optional.empty())); - }, this::empty, this::empty, this::empty, this::empty)); - + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.INilTerm) { + INilTerm nil = (INilTerm) subj; + return Optional.of(f.apply(nil)); + } + return Optional.empty(); + }; } // string @@ -356,12 +416,18 @@ public IMatcher string() { } public IMatcher string(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, this::empty, - string -> Optional.of(f.apply(string)), this::empty, this::empty, this::empty)); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IStringTerm) { + IStringTerm string = (IStringTerm) subj; + return Optional.of(f.apply(string)); + } + return Optional.empty(); + }; } public IMatcher stringValue() { - return string(s -> s.getValue()); + return string(IStringTerm::getValue); } // integer @@ -371,12 +437,18 @@ public IMatcher integer() { } public IMatcher integer(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, this::empty, - this::empty, integer -> Optional.of(f.apply(integer)), this::empty, this::empty)); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IIntTerm) { + IIntTerm integer = (IIntTerm) subj; + return Optional.of(f.apply(integer)); + } + return Optional.empty(); + }; } public IMatcher integerValue() { - return integer(i -> i.getValue()); + return integer(IIntTerm::getValue); } // blob @@ -386,8 +458,14 @@ public IMatcher blob() { } public IMatcher blob(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, this::empty, - this::empty, this::empty, blob -> Optional.of(f.apply(blob)), this::empty)); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.IBlobTerm) { + IBlobTerm blob = (IBlobTerm) subj; + return Optional.of(f.apply(blob)); + } + return Optional.empty(); + }; } @SuppressWarnings("unchecked") public IMatcher blobValue(Class blobClass) { @@ -407,8 +485,14 @@ public IMatcher var() { } public IMatcher var(Function1 f) { - return (term, unifier) -> unifier.findTerm(term).match(Terms.>cases(this::empty, this::empty, - this::empty, this::empty, this::empty, var -> Optional.of(f.apply(var)))); + return (term, unifier) -> { + ITerm subj = unifier.findTerm(term); + if(subj.termTag() == ITerm.Tag.ITermVar) { + ITermVar var = (ITermVar) subj; + return Optional.of(f.apply(var)); + } + return Optional.empty(); + }; } /** @@ -461,10 +545,8 @@ public IMatcher req(IMatcher matcher) { } public IMatcher req(String msg, IMatcher matcher) { - return (term, unifier) -> { - return matcher.match(term, unifier).map(Optional::of) - .orElseThrow(() -> new IllegalArgumentException(msg + ": " + Unifiers.Immutable.of().toString(term, 4))); - }; + return (term, unifier) -> matcher.match(term, unifier).map(Optional::of) + .orElseThrow(() -> new IllegalArgumentException(msg + ": " + Unifiers.Immutable.of().toString(term, 4))); } @SuppressWarnings("unchecked") public IMatcher preserveAttachments(IMatcher matcher) { @@ -499,11 +581,6 @@ public IMatcher> option(IMatcher matcher) { // @formatter:on } - // util - - private Optional empty(@SuppressWarnings("unused") ITerm term) { - return Optional.empty(); - } } @FunctionalInterface @@ -549,7 +626,7 @@ default IMatcher flatMap(Function> fun) { } static IMatcher flatten(IMatcher> m) { - return m.flatMap(o -> o)::match; + return m.flatMap(o -> o); } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermPattern.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermPattern.java index 278180cae..e91ea05e2 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermPattern.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/TermPattern.java @@ -1,7 +1,5 @@ package mb.nabl2.terms.matching; -import static mb.nabl2.terms.build.TermBuild.B; - import java.util.Arrays; import java.util.List; import java.util.Optional; @@ -10,15 +8,21 @@ import com.google.common.collect.ImmutableList; +import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.Terms; import mb.nabl2.terms.build.Attachments; import mb.nabl2.terms.substitution.ISubstitution; import mb.nabl2.terms.unification.u.IUnifier; +import static mb.nabl2.terms.build.TermBuild.B; + public class TermPattern { public static P P = new P(); @@ -30,7 +34,7 @@ public Pattern newAppl(String op, Pattern... args) { } public Pattern newAppl(String op, Iterable args, IAttachments attachments) { - if(op.equals("")) { + if(op.isEmpty()) { throw new IllegalArgumentException(); } return new ApplPattern(op, args, attachments); @@ -130,29 +134,50 @@ public Pattern fromTerm(ITerm term) { } public Pattern fromTerm(ITerm term, Predicate1 isWildcard) { - // @formatter:off - return term.match(Terms.cases( - appl -> { + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; final List args = appl.getArgs(); - final ImmutableList.Builder newArgs = ImmutableList.builderWithExpectedSize(args.size()); + final ImmutableList.Builder newArgs = + ImmutableList.builderWithExpectedSize(args.size()); for(ITerm arg : args) { newArgs.add(fromTerm(arg, isWildcard)); } return new ApplPattern(appl.getOp(), newArgs.build(), appl.getAttachments()); - }, - list -> list.match(ListTerms.cases( - cons -> new ConsPattern(fromTerm(cons.getHead(), isWildcard), fromTerm(cons.getTail(), isWildcard), cons.getAttachments()), - nil -> new NilPattern(nil.getAttachments()), - var -> isWildcard.test(var) ? new PatternVar() : new PatternVar(var) - )), - string -> new StringPattern(string.getValue(), string.getAttachments()), - integer -> new IntPattern(integer.getValue(), integer.getAttachments()), - blob -> { + } + + case IConsTerm: { + IConsTerm cons = (IConsTerm) term; + return new ConsPattern(fromTerm(cons.getHead(), isWildcard), + fromTerm(cons.getTail(), isWildcard), cons.getAttachments()); + } + + case INilTerm: { + INilTerm nil = (INilTerm) term; + return new NilPattern(nil.getAttachments()); + } + + case IStringTerm: { + IStringTerm string = (IStringTerm) term; + return new StringPattern(string.getValue(), string.getAttachments()); + } + + case IIntTerm: { + IIntTerm integer = (IIntTerm) term; + return new IntPattern(integer.getValue(), integer.getAttachments()); + } + + case IBlobTerm: { throw new IllegalArgumentException("Cannot create blob patterns."); - }, - var -> isWildcard.test(var) ? new PatternVar() : new PatternVar(var) - )); - // @formatter:on + } + + case ITermVar: { + ITermVar var = (ITermVar) term; + return isWildcard.test(var) ? new PatternVar() : new PatternVar(var); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } public Optional match(final Iterable patterns, final Iterable terms) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Transform.java b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Transform.java index a529595c5..29b9569fe 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Transform.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/matching/Transform.java @@ -1,8 +1,6 @@ package mb.nabl2.terms.matching; -import static mb.nabl2.terms.build.TermBuild.B; -import static mb.nabl2.terms.matching.TermMatch.M; - +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Optional; @@ -13,13 +11,15 @@ import org.metaborg.util.unit.Unit; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; +import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; import mb.nabl2.terms.ITerm; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.Terms; -import mb.nabl2.terms.matching.TermMatch.IMatcher; + +import static mb.nabl2.terms.build.TermBuild.B; +import static mb.nabl2.terms.matching.TermMatch.M; public class Transform { @@ -28,86 +28,109 @@ public class Transform { public static class T { public static Function1 sometd(PartialFunction1 m) { - // @formatter:off - return term -> m.apply(term).orElseGet(() -> term.match(Terms.cases( - (appl) -> { - final ImmutableList newArgs; - if((newArgs = Terms.applyLazy(appl.getArgs(), sometd(m)::apply)) == null) { - return appl; + return term -> m.apply(term).orElseGet(() -> { + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; + final ImmutableList newArgs; + if((newArgs = Terms.applyLazy(appl.getArgs(), sometd(m))) == null) { + return appl; + } + return B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); + } + + case IConsTerm: { + IConsTerm cons = (IConsTerm) term; + return B.newCons(sometd(m).apply(cons.getHead()), + (IListTerm) sometd(m).apply(cons.getTail()), cons.getAttachments()); + } + + case INilTerm: + case ITermVar: + case IBlobTerm: + case IIntTerm: + case IStringTerm: { + return term; } - return B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); - }, - (list) -> list.match(ListTerms. cases( - (cons) -> B.newCons(sometd(m).apply(cons.getHead()), (IListTerm) sometd(m).apply(cons.getTail()), cons.getAttachments()), - (nil) -> nil, - (var) -> var - )), - (string) -> string, - (integer) -> integer, - (blob) -> blob, - (var) -> var - ))); - // @formatter:on + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + }); } public static Function1 somebu(PartialFunction1 m) { return term -> { - // @formatter:off - ITerm next = term.match(Terms.cases( - (appl) -> { + ITerm next = null; + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; final ImmutableList newArgs; - if((newArgs = Terms.applyLazy(appl.getArgs(), somebu(m)::apply)) == null) { - return appl; + if((newArgs = Terms.applyLazy(appl.getArgs(), somebu(m))) + == null) { + next = appl; + break; } - return B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); - }, - (list) -> list.match(ListTerms. cases( - (cons) -> B.newCons(somebu(m).apply(cons.getHead()), (IListTerm) somebu(m).apply(cons.getTail()), cons.getAttachments()), - (nil) -> nil, - (var) -> var - )), - (string) -> string, - (integer) -> integer, - (blob) -> blob, - (var) -> var - )); - // @formatter:on + next = B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); + break; + } + + case IConsTerm: { + IConsTerm cons = (IConsTerm) term; + next = B.newCons(somebu(m).apply(cons.getHead()), + (IListTerm) somebu(m).apply(cons.getTail()), + cons.getAttachments()); + break; + } + + case INilTerm: + case ITermVar: + case IBlobTerm: + case IIntTerm: + case IStringTerm: { + next = term; + break; + } + } return m.apply(next).orElse(next); }; } - public static Function1> collecttd(PartialFunction1 m) { + public static Function1> collecttd( + PartialFunction1 m) { return term -> { - List results = Lists.newArrayList(); - M.casesFix(f -> Iterables2.>from( - // @formatter:off - (t, u) -> m.apply(t).map(r -> { + List results = new ArrayList<>(); + M.casesFix( + f -> Iterables2.from((t, u) -> m.apply(t).map(r -> { results.add(r); return Unit.unit; - }), - (t, u) -> Optional.of(t.match(Terms.cases( - (appl) -> { - for(ITerm arg : appl.getArgs()) { - f.match(arg, u); + }), (t, u) -> { + switch(t.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) t; + for(ITerm arg : appl.getArgs()) { + f.match(arg, u); + } + return Optional.of(Unit.unit); } - return Unit.unit; - }, - (list) -> list.match(ListTerms. cases( - (cons) -> { + + case IConsTerm: { + IConsTerm cons = (IConsTerm) t; f.match(cons.getHead(), u); f.match(cons.getTail(), u); - return Unit.unit; - }, - (nil) -> Unit.unit, - (var) -> Unit.unit - )), - (string) -> Unit.unit, - (integer) -> Unit.unit, - (blob) -> Unit.unit, - (var) -> Unit.unit - ))) - // @formatter:on - )).match(term); + return Optional.of(Unit.unit); + } + + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: + case ITermVar: { + return Optional.of(Unit.unit); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + })).match(term); return results; }; } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/ATermIndex.java b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/ATermIndex.java index 48e670c6a..f813c0f33 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/ATermIndex.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/ATermIndex.java @@ -1,8 +1,5 @@ package mb.nabl2.terms.stratego; -import static mb.nabl2.terms.build.TermBuild.B; -import static mb.nabl2.terms.matching.TermMatch.M; - import java.util.Iterator; import java.util.List; import java.util.Objects; @@ -15,12 +12,21 @@ import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; +import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; -import mb.nabl2.terms.ListTerms; +import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.Terms; import mb.nabl2.terms.build.AbstractApplTerm; import mb.nabl2.terms.matching.TermMatch.IMatcher; +import static mb.nabl2.terms.build.TermBuild.B; +import static mb.nabl2.terms.matching.TermMatch.M; + @Value.Immutable(lazyhash = false) @Serial.Version(value = 42L) public abstract class ATermIndex extends AbstractApplTerm implements ITermIndex, IApplTerm { @@ -108,16 +114,27 @@ public static Optional get(IAttachments attachments) { * @return The {@link TermIndex} of the first eligible term. */ public static Optional find(ITerm term) { - // @formatter:off - return get(term.getAttachments()).map(Optional::of).orElseGet(() -> - term.match(Terms.>cases() - .appl(appl -> find(appl.getArgs().iterator())) - .list(list -> list.match(ListTerms.>cases() - .cons(cons -> find(cons.getHead()).map(Optional::of).orElseGet(() -> find(cons.getTail()))) - .otherwise(__ -> Optional.empty()))) - .otherwise(__ -> Optional.empty())) - ); - // @formatter:on + return get(term.getAttachments()).map(Optional::of).orElseGet(() -> { + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + return find(appl.getArgs().iterator()); + } + + case IConsTerm: { IConsTerm cons = (IConsTerm) term; + return find(cons.getHead()).map(Optional::of).orElseGet(() -> find(cons.getTail())); + } + + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: + case ITermVar: { + return Optional.empty(); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + }); } private static Optional find(Iterator termIterator) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoPlaceholders.java b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoPlaceholders.java index 6efd04aa0..a1f0c3e26 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoPlaceholders.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoPlaceholders.java @@ -13,11 +13,14 @@ import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; -import mb.nabl2.terms.Terms; import mb.nabl2.terms.build.TermBuild; /** @@ -38,22 +41,29 @@ private StrategoPlaceholders() {} * otherwise, {@code false} */ public static boolean isLiteralVar(ITerm term) { - return term.match(Terms.casesFix( - (m, appl) -> { - if (!isInjectionConstructor(appl)) return false; + // Injection + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + if(!StrategoPlaceholders.isInjectionConstructor(appl)) + return false; // Injection return isLiteralVar(appl.getArgs().get(0)); - }, - (m, list) -> list.match(ListTerms.casesFix( - (lm, cons) -> false, - (lm, nil) -> false, - (lm, var) -> false - )), - (m, string) -> false, - (m, integer) -> false, - (m, blob) -> false, - (m, var) -> isLiteralSort(getSortFromAttachments(var.getAttachments())) - )); + } + + case ITermVar: { ITermVar var = (ITermVar) term; + return isLiteralSort(getSortFromAttachments(var.getAttachments())); + } + + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } /** @@ -64,25 +74,35 @@ public static boolean isLiteralVar(ITerm term) { * otherwise, {@code false} */ public static boolean containsLiteralVar(ITerm term) { - return term.match(Terms.casesFix( - (m, appl) -> { - if (isPlaceholder(appl)) { + // Placeholder + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + if(StrategoPlaceholders.isPlaceholder(appl)) { // Placeholder return isLiteralSort(getSortFromAttachments(appl.getAttachments())); } else { - return appl.getArgs().stream().anyMatch(a -> a.match(m)); + return appl.getArgs().stream().anyMatch( + StrategoPlaceholders::containsLiteralVar); } - }, - (m, list) -> list.match(ListTerms.casesFix( - (lm, cons) -> cons.getHead().match(m) || cons.getTail().match(lm), - (lm, nil) -> false, - (lm, var) -> isLiteralSort(getSortFromAttachments(var.getAttachments())) - )), - (m, string) -> false, - (m, integer) -> false, - (m, blob) -> false, - (m, var) -> isLiteralSort(getSortFromAttachments(var.getAttachments())) - )); + } + + case IConsTerm: { IConsTerm cons = (IConsTerm) term; + return containsLiteralVar(cons.getHead()) || containsLiteralVar(cons.getTail()); + } + + case ITermVar: { ITermVar var = (ITermVar) term; + return isLiteralSort(getSortFromAttachments(var.getAttachments())); + } + + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } /** @@ -189,25 +209,38 @@ public static boolean isPlaceholder(ITerm term) { * @return the term with its placeholders replaced */ public static ITerm replacePlaceholdersByVariables(ITerm term, PlaceholderVarMap placeholderVarMap) { - return term.match(Terms.casesFix( - (m, appl) -> { - if (isPlaceholder(appl)) { + // Placeholder + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; + if(StrategoPlaceholders.isPlaceholder(appl)) { // Placeholder return placeholderVarMap.addPlaceholderMapping(appl); } else { - return TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream().map(a -> a.match(m)).collect(Collectors.toList()), appl.getAttachments()); + return TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream() + .map(a -> replacePlaceholdersByVariables(a, placeholderVarMap)) + .collect(Collectors.toList()), appl.getAttachments()); } - }, - (m, list) -> list.match(ListTerms.casesFix( - (lm, cons) -> TermBuild.B.newCons(cons.getHead().match(m), cons.getTail().match(lm), cons.getAttachments()), - (lm, nil) -> nil, - (lm, var) -> var - )), - (m, string) -> string, - (m, integer) -> integer, - (m, blob) -> blob, - (m, var) -> var - )); + } + + case IConsTerm: { + IConsTerm cons = (IConsTerm) term; + return TermBuild.B.newCons( + replacePlaceholdersByVariables(cons.getHead(), placeholderVarMap), + (IListTerm) replacePlaceholdersByVariables(cons.getTail(), placeholderVarMap), + cons.getAttachments()); + } + + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: + case ITermVar: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } /** @@ -218,30 +251,59 @@ public static ITerm replacePlaceholdersByVariables(ITerm term, PlaceholderVarMap * @return the term with its term variables replaced */ public static ITerm replaceVariablesByPlaceholders(ITerm term, PlaceholderVarMap placeholderVarMap) { - return term.match(Terms.cases( - appl -> { - if (isInjectionConstructor(appl) && onlyInjectionConstructorsAndVariables(appl)) { + // TODO: Ability to relate placeholders, such that typing in the editor in one placeholder also types in another + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + if(StrategoPlaceholders.isInjectionConstructor(appl) + && onlyInjectionConstructorsAndVariables(appl)) { return getPlaceholderForTerm(appl); } else { - return TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream().map(a -> replaceVariablesByPlaceholders(a, placeholderVarMap)).collect(Collectors.toList()), appl.getAttachments()); + return TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream() + .map(a -> replaceVariablesByPlaceholders(a, placeholderVarMap)) + .collect(Collectors.toList()), appl.getAttachments()); } - }, - list -> replaceVariablesByPlaceholdersInList(list, placeholderVarMap), - string -> string, - integer -> integer, - blob -> blob, - // TODO: Ability to relate placeholders, such that typing in the editor in one placeholder also types in another - var -> getPlaceholderForVar(var, placeholderVarMap) - )); + } + + case IConsTerm: + case INilTerm: { IListTerm list = (IListTerm) term; + return replaceVariablesByPlaceholdersInList(list, placeholderVarMap); + } + + case ITermVar: { ITermVar var = (ITermVar) term; + return getPlaceholderForVar(var, placeholderVarMap); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } public static IListTerm replaceVariablesByPlaceholdersInList(IListTerm term, PlaceholderVarMap placeholderVarMap) { - return term.match(ListTerms.cases( - cons -> TermBuild.B.newCons(replaceVariablesByPlaceholders(cons.getHead(), placeholderVarMap), replaceVariablesByPlaceholdersInList(cons.getTail(), placeholderVarMap), cons.getAttachments()), - nil -> nil, - var -> TermBuild.B.newCons(replaceVariablesByPlaceholders(var, placeholderVarMap), TermBuild.B.newNil()) - //var -> TermBuild.B.newNil()// var // FIXME: Should be make a placeholder for list tails? - )); + //var -> TermBuild.B.newNil()// var // FIXME: Should be make a placeholder for list tails? + switch(term.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) term; + return TermBuild.B.newCons( + replaceVariablesByPlaceholders(cons.getHead(), placeholderVarMap), + replaceVariablesByPlaceholdersInList(cons.getTail(), placeholderVarMap), + cons.getAttachments()); + } + + case ITermVar: { ITermVar var = (ITermVar) term; + return TermBuild.B.newCons(replaceVariablesByPlaceholders(var, placeholderVarMap), + TermBuild.B.newNil()); + } + + case INilTerm: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } /** @@ -253,22 +315,46 @@ public static IListTerm replaceVariablesByPlaceholdersInList(IListTerm term, Pla * @return the term with its term variables replaced */ public static ITerm replaceListVariablesByEmptyList(ITerm term) { - return term.match(Terms.cases( - appl -> TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream().map(StrategoPlaceholders::replaceListVariablesByEmptyList).collect(Collectors.toList()), appl.getAttachments()), - list -> replaceListVariablesByEmptyListInList(list), - string -> string, - integer -> integer, - blob -> blob, - var -> var - )); + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + return TermBuild.B.newAppl(appl.getOp(), appl.getArgs().stream().map(StrategoPlaceholders::replaceListVariablesByEmptyList) + .collect(Collectors.toList()), appl.getAttachments()); + } + + case IConsTerm: + case INilTerm: { IListTerm list = (IListTerm) term; + return replaceListVariablesByEmptyListInList(list); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: + case ITermVar: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } public static IListTerm replaceListVariablesByEmptyListInList(IListTerm term) { - return term.match(ListTerms.cases( - cons -> TermBuild.B.newCons(replaceListVariablesByEmptyList(cons.getHead()), replaceListVariablesByEmptyListInList(cons.getTail()), cons.getAttachments()), - nil -> nil, - var -> TermBuild.B.newNil() - )); + switch(term.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) term; + return TermBuild.B.newCons(replaceListVariablesByEmptyList(cons.getHead()), + replaceListVariablesByEmptyListInList(cons.getTail()), cons.getAttachments()); + } + + case INilTerm: { + return term; + } + + case ITermVar: { + return TermBuild.B.newNil(); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } public static boolean isInjectionConstructor(ITerm term) { @@ -314,20 +400,30 @@ public static ITermVar getInjectionArgument(IApplTerm appl) { } public static boolean onlyInjectionConstructorsAndVariables(ITerm term) { - return term.match(Terms.cases( - appl -> { - if (!isInjectionConstructor(appl)) { + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + if(!StrategoPlaceholders.isInjectionConstructor(appl)) { return false; } else { - return appl.getArgs().stream().allMatch(StrategoPlaceholders::onlyInjectionConstructorsAndVariables); + return appl.getArgs().stream() + .allMatch(StrategoPlaceholders::onlyInjectionConstructorsAndVariables); } - }, - list -> false, - string -> false, - integer -> false, - blob -> false, - var -> true - )); + } + + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + + case ITermVar: { + return true; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } /** diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTermIndices.java b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTermIndices.java index 32a416d18..0f17e95d2 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTermIndices.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTermIndices.java @@ -5,10 +5,17 @@ import java.util.List; import java.util.Optional; +import org.spoofax.interpreter.terms.IStrategoAppl; import org.spoofax.interpreter.terms.IStrategoConstructor; +import org.spoofax.interpreter.terms.IStrategoInt; import org.spoofax.interpreter.terms.IStrategoList; +import org.spoofax.interpreter.terms.IStrategoPlaceholder; +import org.spoofax.interpreter.terms.IStrategoReal; +import org.spoofax.interpreter.terms.IStrategoString; import org.spoofax.interpreter.terms.IStrategoTerm; +import org.spoofax.interpreter.terms.IStrategoTuple; import org.spoofax.interpreter.terms.ITermFactory; +import org.spoofax.terms.StrategoTerm; import org.spoofax.terms.TermFactory; import org.spoofax.terms.util.TermUtils; @@ -41,21 +48,62 @@ private static class Indexer { } private IStrategoTerm index(final IStrategoTerm term) { - // @formatter:off - IStrategoTerm result = StrategoTerms.match(term, - StrategoTerms.cases( - appl -> termFactory.makeAppl(appl.getConstructor(), index(appl.getAllSubterms()), appl.getAnnotations()), - tuple -> termFactory.makeTuple(index(tuple.getAllSubterms()), tuple.getAnnotations()), - list -> index(list), - integer -> termFactory.annotateTerm(termFactory.makeInt(integer.intValue()), integer.getAnnotations()), - real -> termFactory.annotateTerm(termFactory.makeReal(real.realValue()), real.getAnnotations()), - string -> termFactory.annotateTerm(termFactory.makeString(string.stringValue()), string.getAnnotations()), - blob -> new StrategoBlob(blob.value()), - plhdr -> termFactory.annotateTerm(termFactory.makePlaceholder(plhdr.getTemplate()), plhdr.getAnnotations()) - )); - // @formatter:on + IStrategoTerm result; + switch(term.getType()) { + case APPL: { + IStrategoAppl appl = (IStrategoAppl) term; + result = + termFactory.makeAppl(appl.getConstructor(), index(appl.getAllSubterms()), + appl.getAnnotations()); + break; + } + case TUPLE: { + IStrategoTuple tuple = (IStrategoTuple) term; + result = termFactory.makeTuple(index(tuple.getAllSubterms()), + tuple.getAnnotations()); + break; + } + case LIST: { + result = index((IStrategoList) term); + break; + } + case INT: { + IStrategoInt integer = (IStrategoInt) term; + result = termFactory.annotateTerm(termFactory.makeInt(integer.intValue()), + integer.getAnnotations()); + break; + } + case REAL: { + IStrategoReal real = (IStrategoReal) term; + result = termFactory.annotateTerm(termFactory.makeReal(real.realValue()), + real.getAnnotations()); + break; + } + case STRING: { + IStrategoString string = (IStrategoString) term; + result = termFactory.annotateTerm(termFactory.makeString(string.stringValue()), + string.getAnnotations()); + break; + } + case BLOB: { + StrategoBlob blob = (StrategoBlob) term; + result = new StrategoBlob(blob.value()); + break; + } + case PLACEHOLDER: { + IStrategoPlaceholder plhdr = (IStrategoPlaceholder) term; + result = + termFactory.annotateTerm(termFactory.makePlaceholder(plhdr.getTemplate()), + plhdr.getAnnotations()); + break; + } + default: { + throw new IllegalArgumentException( + "Unsupported Stratego term type " + term.getType()); + } + } final TermIndex index1 = TermIndex.of(resource, ++currentId); - final TermIndex index2 = (TermIndex) TermOrigin.get(term).map(o -> o.put(index1)).orElse(index1); + final TermIndex index2 = TermOrigin.get(term).map(o -> o.put(index1)).orElse(index1); result = put(index2, result, termFactory); termFactory.copyAttachments(term, result); return result; @@ -70,8 +118,8 @@ private IStrategoList index(final IStrategoList list) { } termFactory.copyAttachments(list, result); final TermIndex index1 = TermIndex.of(resource, ++currentId); - final TermIndex index2 = (TermIndex) TermOrigin.get(list).map(o -> o.put(index1)).orElse(index1); - result = (IStrategoList) put(index2, result, termFactory); + final TermIndex index2 = TermOrigin.get(list).map(o -> o.put(index1)).orElse(index1); + result = put(index2, result, termFactory); return result; } @@ -97,19 +145,60 @@ private static class Eraser { } private IStrategoTerm erase(final IStrategoTerm term) { - IStrategoTerm result = StrategoTerms.match(term, StrategoTerms.cases( - // @formatter:off - appl -> termFactory.makeAppl(appl.getConstructor(), erase(appl.getAllSubterms()), - appl.getAnnotations()), - tuple -> termFactory.makeTuple(erase(tuple.getAllSubterms()), tuple.getAnnotations()), - list -> erase(list), - integer -> termFactory.annotateTerm(termFactory.makeInt(integer.intValue()), integer.getAnnotations()), - real -> termFactory.annotateTerm(termFactory.makeReal(real.realValue()), real.getAnnotations()), - string -> termFactory.annotateTerm(termFactory.makeString(string.stringValue()), string.getAnnotations()), - blob -> new StrategoBlob(blob.value()), - plhdr -> termFactory.annotateTerm(termFactory.makePlaceholder(plhdr.getTemplate()), plhdr.getAnnotations()) - // @formatter:on - )); + IStrategoTerm result; + switch(term.getType()) { + case APPL: { + IStrategoAppl appl = (IStrategoAppl) term; + result = + termFactory.makeAppl(appl.getConstructor(), erase(appl.getAllSubterms()), + appl.getAnnotations()); + break; + } + case TUPLE: { + IStrategoTuple tuple = (IStrategoTuple) term; + result = termFactory.makeTuple(erase(tuple.getAllSubterms()), + tuple.getAnnotations()); + break; + } + case LIST: { + result = erase((IStrategoList) term); + break; + } + case INT: { + IStrategoInt integer = (IStrategoInt) term; + result = termFactory.annotateTerm(termFactory.makeInt(integer.intValue()), + integer.getAnnotations()); + break; + } + case REAL: { + IStrategoReal real = (IStrategoReal) term; + result = termFactory.annotateTerm(termFactory.makeReal(real.realValue()), + real.getAnnotations()); + break; + } + case STRING: { + IStrategoString string = (IStrategoString) term; + result = termFactory.annotateTerm(termFactory.makeString(string.stringValue()), + string.getAnnotations()); + break; + } + case BLOB: { + StrategoBlob blob = (StrategoBlob) term; + result = new StrategoBlob(blob.value()); + break; + } + case PLACEHOLDER: { + IStrategoPlaceholder plhdr = (IStrategoPlaceholder) term; + result = + termFactory.annotateTerm(termFactory.makePlaceholder(plhdr.getTemplate()), + plhdr.getAnnotations()); + break; + } + default: { + throw new IllegalArgumentException( + "Unsupported Stratego term type " + term.getType()); + } + } termFactory.copyAttachments(term, result); result = remove(result, termFactory); assert !get(result).isPresent(); @@ -124,7 +213,7 @@ private IStrategoList erase(final IStrategoList list) { result = termFactory.makeListCons(erase(list.head()), erase(list.tail()), list.getAnnotations()); } termFactory.copyAttachments(list, result); - result = (IStrategoList) remove(result, termFactory); + result = remove(result, termFactory); assert !get(result).isPresent(); return result; } @@ -177,7 +266,7 @@ public static Optional match(IStrategoTerm term) { IStrategoTerm idTerm = term.getSubterm(1); final TermIndex index1 = TermIndex.of(TermUtils.toJavaString(resourceTerm), TermUtils.toJavaInt(idTerm)); - final TermIndex index2 = (TermIndex) TermOrigin.get(term).map(o -> o.put(index1)).orElse(index1); + final TermIndex index2 = TermOrigin.get(term).map(o -> o.put(index1)).orElse(index1); return Optional.of(index2); } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTerms.java b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTerms.java index 2b66bbbf2..b1d24f81f 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTerms.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/stratego/StrategoTerms.java @@ -1,8 +1,9 @@ package mb.nabl2.terms.stratego; -import static mb.nabl2.terms.build.TermBuild.B; - +import java.util.ArrayDeque; +import java.util.Deque; import java.util.LinkedList; +import java.util.List; import java.util.Optional; import javax.annotation.Nullable; @@ -22,14 +23,21 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import mb.nabl2.terms.IApplTerm; import mb.nabl2.terms.IAttachments; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; -import mb.nabl2.terms.ListTerms; +import mb.nabl2.terms.ITermVar; import mb.nabl2.terms.Terms; import mb.nabl2.terms.build.Attachments; import mb.nabl2.terms.matching.VarProvider; +import static mb.nabl2.terms.build.TermBuild.B; + public class StrategoTerms { private final org.spoofax.interpreter.terms.ITermFactory termFactory; @@ -45,27 +53,55 @@ public IStrategoTerm toStratego(ITerm term) { } public IStrategoTerm toStratego(ITerm term, boolean varsToPlhdrs) { - // @formatter:off - IStrategoTerm strategoTerm = term.match(Terms.cases( - appl -> { + IStrategoTerm strategoTerm = null; + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; + List args = appl.getArgs(); IStrategoTerm[] argArray = appl.getArgs().stream().map(arg -> toStratego(arg, varsToPlhdrs)).toArray(IStrategoTerm[]::new); - return appl.getOp().equals(Terms.TUPLE_OP) - ? termFactory.makeTuple(argArray) - : termFactory.makeAppl(termFactory.makeConstructor(appl.getOp(), appl.getArity()), argArray); - }, - list -> toStrategoList(list, varsToPlhdrs), - string -> termFactory.makeString(string.getValue()), - integer -> termFactory.makeInt(integer.getValue()), - blob -> new StrategoBlob(blob.getValue()), - var -> { - if (varsToPlhdrs) { - return termFactory.makePlaceholder(termFactory.makeTuple(termFactory.makeString(var.getResource()), termFactory.makeString(var.getName()))); + strategoTerm = appl.getOp().equals(Terms.TUPLE_OP) ? termFactory.makeTuple(argArray) : + termFactory.makeAppl(termFactory.makeConstructor(appl.getOp(), appl.getArity()), argArray); + break; + } + + case IConsTerm: + case INilTerm: { + IListTerm list = (IListTerm) term; + strategoTerm = toStrategoList(list, varsToPlhdrs); + break; + } + + case IStringTerm: { + IStringTerm string = (IStringTerm) term; + strategoTerm = termFactory.makeString(string.getValue()); + break; + } + + case IIntTerm: { + IIntTerm integer = (IIntTerm) term; + strategoTerm = termFactory.makeInt(integer.getValue()); + break; + } + + case IBlobTerm: { + IBlobTerm blob = (IBlobTerm) term; + strategoTerm = new StrategoBlob(blob.getValue()); + break; + } + + case ITermVar: { + ITermVar var = (ITermVar) term; + if(varsToPlhdrs) { + strategoTerm = termFactory.makePlaceholder(termFactory.makeTuple(termFactory.makeString(var.getResource()), + termFactory.makeString(var.getName()))); } else { - return termFactory.makeAppl("nabl2.Var", termFactory.makeString(var.getResource()), termFactory.makeString(var.getName())); + strategoTerm = termFactory.makeAppl("nabl2.Var", termFactory.makeString(var.getResource()), + termFactory.makeString(var.getName())); } + break; } - )); - // @formatter:on + } + assert strategoTerm != null; switch(strategoTerm.getType()) { case BLOB: case LIST: @@ -77,24 +113,28 @@ public IStrategoTerm toStratego(ITerm term, boolean varsToPlhdrs) { } private IStrategoTerm toStrategoList(IListTerm list, boolean varsToPlhdrs) { - final LinkedList terms = Lists.newLinkedList(); - final LinkedList attachments = Lists.newLinkedList(); + final Deque terms = new ArrayDeque<>(list.getMinSize()); + final Deque attachments = new ArrayDeque<>(list.getMinSize()); while(list != null) { attachments.push(list.getAttachments()); - // @formatter:off - list = list.match(ListTerms.cases( - cons -> { + switch(list.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) list; terms.push(toStratego(cons.getHead(), varsToPlhdrs)); - return cons.getTail(); - }, - nil -> { - return null; - }, - var -> { - throw new IllegalArgumentException("Cannot convert specialized terms to Stratego."); + list = cons.getTail(); + break; + } + + case INilTerm: { + list = null; + break; + } + + case ITermVar: { + throw new IllegalArgumentException( + "Cannot convert specialized terms to Stratego."); } - )); - // @formatter:on + } } IStrategoList strategoList = termFactory.makeList(); putAttachments(strategoList, attachments.pop()); @@ -136,41 +176,65 @@ public ITerm fromStratego(IStrategoTerm sterm) { return fromStratego(sterm, null); } - public ITerm fromStratego(IStrategoTerm sterm, @Nullable VarProvider varProvider) { - @Nullable IAttachments attachments = getAttachments(sterm); - // @formatter:off - ITerm term = match(sterm, StrategoTerms.cases( - appl -> { + public ITerm fromStratego(IStrategoTerm term, @Nullable VarProvider varProvider) { + @Nullable IAttachments attachments = getAttachments(term); + ITerm result; + switch(term.getType()) { + case APPL: { + IStrategoAppl appl = (IStrategoAppl) term; final IStrategoTerm[] subTerms = appl.getAllSubterms(); final ImmutableList.Builder args = ImmutableList.builderWithExpectedSize(subTerms.length); for(IStrategoTerm subTerm : subTerms) { args.add(fromStratego(subTerm, varProvider)); } - return B.newAppl(appl.getConstructor().getName(), args.build(), attachments); - }, - tuple -> { + result = B.newAppl(appl.getConstructor().getName(), args.build(), attachments); + break; + } + case TUPLE: { + IStrategoTuple tuple = (IStrategoTuple) term; final IStrategoTerm[] subTerms = tuple.getAllSubterms(); final ImmutableList.Builder args = ImmutableList.builderWithExpectedSize(subTerms.length); for(IStrategoTerm subTerm : subTerms) { args.add(fromStratego(subTerm, varProvider)); } - return B.newTuple(args.build(), attachments); - }, - list -> fromStrategoList(list, varProvider), - integer -> B.newInt(integer.intValue(), attachments), - real -> { throw new IllegalArgumentException("Real values are not supported."); }, - string -> B.newString(string.stringValue(), attachments), - blob -> B.newBlob(blob.value()), - plhdr -> { - if (varProvider != null) { - return varProvider.freshWld(); + result = B.newTuple(args.build(), attachments); + break; + } + case LIST: { + result = fromStrategoList((IStrategoList) term, varProvider); + break; + } + case INT: { + IStrategoInt integer = (IStrategoInt) term; + result = B.newInt(integer.intValue(), attachments); + break; + } + case REAL: { + throw new IllegalArgumentException("Real values are not supported."); + } + case STRING: { + IStrategoString string = (IStrategoString) term; + result = B.newString(string.stringValue(), attachments); + break; + } + case BLOB: { + StrategoBlob blob = (StrategoBlob) term; + result = B.newBlob(blob.value()); + break; + } + case PLACEHOLDER: { + if(varProvider != null) { + result = varProvider.freshWld(); } else { throw new IllegalArgumentException("Placeholders are not supported."); } + break; } - )); - // @formatter:on - return term; + default: { + throw new IllegalArgumentException("Unsupported Stratego term type " + term.getType()); + } + } + return result; } private IListTerm fromStrategoList(IStrategoList list, @Nullable VarProvider varProvider) { @@ -188,13 +252,9 @@ private IListTerm fromStrategoList(IStrategoList list, @Nullable VarProvider var public static IAttachments getAttachments(IStrategoTerm term) { final Attachments.Builder b = Attachments.Builder.of(); - TermOrigin.get(term).ifPresent(origin -> { - b.put(TermOrigin.class, origin); - }); + TermOrigin.get(term).ifPresent(origin -> b.put(TermOrigin.class, origin)); - StrategoTermIndices.get(term).ifPresent(termIndex -> { - b.put(TermIndex.class, termIndex); - }); + StrategoTermIndices.get(term).ifPresent(termIndex -> b.put(TermIndex.class, termIndex)); final IStrategoList annos = term.getAnnotations(); if(!annos.isEmpty()) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/FreshVars.java b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/FreshVars.java index 1d9215ecb..e34644a08 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/FreshVars.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/FreshVars.java @@ -3,6 +3,7 @@ import static mb.nabl2.terms.build.TermBuild.B; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.metaborg.util.collection.CapsuleUtil; @@ -30,7 +31,7 @@ public FreshVars() { } @SafeVarargs public FreshVars(java.util.Set... preExistingVarSets) { - this.oldVarSets = Lists.newArrayList(preExistingVarSets); + this.oldVarSets = new ArrayList<>(Arrays.asList(preExistingVarSets)); this.oldVars = CapsuleUtil.immutableSet(); this.newVars = CapsuleUtil.immutableSet(); } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/PersistentSubstitution.java b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/PersistentSubstitution.java index f7c2a4478..65fe95b6f 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/PersistentSubstitution.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/PersistentSubstitution.java @@ -1,23 +1,25 @@ package mb.nabl2.terms.substitution; -import static mb.nabl2.terms.build.TermBuild.B; - import java.io.Serializable; import java.util.Map.Entry; import java.util.Set; import org.metaborg.util.collection.CapsuleUtil; import org.metaborg.util.collection.MultiSet; +import org.metaborg.util.functions.Function1; import com.google.common.collect.ImmutableList; import io.usethesource.capsule.Map; +import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.Terms; +import static mb.nabl2.terms.build.TermBuild.B; + public abstract class PersistentSubstitution implements ISubstitution { protected abstract Map subst(); @@ -48,42 +50,61 @@ public abstract class PersistentSubstitution implements ISubstitution { if(term.isGround()) { return term; } - // @formatter:off - return term.match(Terms.cases( - appl -> { + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; final ImmutableList newArgs; - if((newArgs = Terms.applyLazy(appl.getArgs(), this::apply)) == null) { + if((newArgs = Terms.applyLazy(appl.getArgs(), this::apply)) + == null) { return appl; } return B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); - }, - list -> apply(list), - string -> string, - integer -> integer, - blob -> blob, - var -> apply(var) - )); - // @formatter:on + } + + case IConsTerm: + case INilTerm: { + return apply((IListTerm) term); + } + + case ITermVar: { + return apply((ITermVar) term); + } + + case IStringTerm: + case IBlobTerm: + case IIntTerm: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private IListTerm apply(IListTerm list) { if(list.isGround()) { return list; } - // @formatter:off - return list.match(ListTerms.cases( - cons -> { + switch(list.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) list; final ITerm newHead = apply(cons.getHead()); final IListTerm newTail = apply(cons.getTail()); if(newHead == cons.getHead() && newTail == cons.getTail()) { return cons; } return B.newCons(newHead, newTail, cons.getAttachments()); - }, - nil -> nil, - var -> (IListTerm) apply(var) - )); - // @formatter:on + } + + case INilTerm: { + return list; + } + + case ITermVar: { + return (IListTerm) apply((ITermVar) list); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } private ITerm apply(ITermVar var) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Renaming.java b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Renaming.java index 7ebe6722b..ae6208c1a 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Renaming.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Renaming.java @@ -1,23 +1,26 @@ package mb.nabl2.terms.substitution; -import static mb.nabl2.terms.build.TermBuild.B; - import java.util.Collections; import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.stream.Collectors; +import org.metaborg.util.functions.Function1; + import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; +import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IConsTerm; import mb.nabl2.terms.IListTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.Terms; +import static mb.nabl2.terms.build.TermBuild.B; + public class Renaming implements IRenaming { private final BiMap renaming; @@ -47,32 +50,53 @@ private Renaming(BiMap renaming) { } @Override public ITerm apply(ITerm term) { - // @formatter:off - return term.match(Terms.cases( - appl -> { + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; final ImmutableList newArgs; if((newArgs = Terms.applyLazy(appl.getArgs(), this::apply)) == null) { return appl; } return B.newAppl(appl.getOp(), newArgs, appl.getAttachments()); - }, - list -> apply(list), - string -> string, - integer -> integer, - blob -> blob, - var -> rename(var) - )); - // @formatter:on + } + + case IConsTerm: { + return apply((IListTerm) term); + } + + case ITermVar: { + return rename((ITermVar) term); + } + + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private IListTerm apply(IListTerm list) { - // @formatter:off - return list.match(ListTerms.cases( - cons -> B.newCons(apply(cons.getHead()), apply(cons.getTail()), cons.getAttachments()), - nil -> nil, - var -> (IListTerm) rename(var) - )); - // @formatter:on + switch(list.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) list; + return B.newCons(apply(cons.getHead()), + apply(cons.getTail()), cons.getAttachments()); + } + + case INilTerm: { + return list; + } + + case ITermVar: { + return rename((ITermVar) list); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } @Override public Map asMap() { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Replacement.java b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Replacement.java index 25f44035e..2fe4bc598 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Replacement.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/substitution/Replacement.java @@ -1,23 +1,28 @@ package mb.nabl2.terms.substitution; import java.util.Map.Entry; +import java.util.Set; import java.util.stream.Collectors; +import org.metaborg.util.functions.Function1; + import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; -import static mb.nabl2.terms.build.TermBuild.B; - -import java.util.Set; - import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; import mb.nabl2.terms.Terms; +import static mb.nabl2.terms.build.TermBuild.B; + public class Replacement implements IReplacement { private final BiMap replacement; @@ -51,26 +56,38 @@ private Replacement(BiMap replacement, boolean traverseSubTerms) { } @Override public ITerm apply(ITerm term) { - // @formatter:off - return term.match(Terms.cases( - appl -> { + // Cannot happen + switch(term.termTag()) { + case IApplTerm: { final IApplTerm newAppl = (IApplTerm) replace(term); if(!traverseSubTerms) { return newAppl; } final ImmutableList newArgs; - if((newArgs = Terms.applyLazy(newAppl.getArgs(), this::apply)) == null) { + if((newArgs = Terms.applyLazy(newAppl.getArgs(), this::apply)) + == null) { return newAppl; } - return B.newAppl(newAppl.getOp(), newArgs, appl.getAttachments()); - }, - list -> apply(list), - string -> replace(string), - integer -> replace(integer), - blob -> replace(blob), - var -> var // Cannot happen - )); - // @formatter:on + return B.newAppl(newAppl.getOp(), newArgs, term.getAttachments()); + } + + case IConsTerm: + case INilTerm: { + return apply((IListTerm) term); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return replace(term); + } + + case ITermVar: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private IListTerm apply(IListTerm list) { @@ -78,21 +95,19 @@ private IListTerm apply(IListTerm list) { if(!traverseSubTerms) { return newList; } - // @formatter:off - return newList.match(ListTerms.cases( - cons -> { - final ITerm newHead = apply(cons.getHead()); - final IListTerm newTail = apply(cons.getTail()); - - if(newHead != cons.getHead() || newTail != cons.getTail()) { - B.newCons(newHead, newTail, cons.getAttachments()); - } - return cons; - }, - nil -> nil, - var -> var // Cannot happen - )); - // @formatter:on + // Cannot happen + if(newList.listTermTag() == IListTerm.Tag.IConsTerm) { + IConsTerm cons = (IConsTerm) newList; + final ITerm newHead = apply(cons.getHead()); + final IListTerm newTail = apply(cons.getTail()); + + if(newHead != cons.getHead() || newTail != cons.getTail()) { + B.newCons(newHead, newTail, cons.getAttachments()); + } + return cons; + } else { + return newList; + } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/RigidException.java b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/RigidException.java index 3e2faf3bf..cf7a2c286 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/RigidException.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/RigidException.java @@ -17,9 +17,9 @@ public RigidException(ITermVar var) { this.vars = ImmutableSet.of(var); } - public RigidException(ITermVar var1, ITermVar var2) { + public RigidException(ITermVar var, ITermVar var2) { super("rigid", null, false, false); - this.vars = ImmutableSet.of(var1, var2); + this.vars = ImmutableSet.of(var, var2); } public RigidException(Iterable vars) { diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/BaseUnifier.java b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/BaseUnifier.java index 04179dbde..27d66b264 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/BaseUnifier.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/BaseUnifier.java @@ -1,9 +1,8 @@ package mb.nabl2.terms.unification.u; -import static mb.nabl2.terms.build.TermBuild.B; - import java.io.Serializable; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Deque; import java.util.LinkedList; import java.util.List; @@ -22,17 +21,23 @@ import io.usethesource.capsule.Map; import io.usethesource.capsule.Set; +import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; -import mb.nabl2.terms.Terms; import mb.nabl2.terms.substitution.ISubstitution; import mb.nabl2.terms.unification.OccursException; import mb.nabl2.terms.unification.RigidException; import mb.nabl2.terms.unification.SpecializedTermFormatter; import mb.nabl2.terms.unification.TermSize; +import static mb.nabl2.terms.build.TermBuild.B; + public abstract class BaseUnifier implements IUnifier, Serializable { private static final long serialVersionUID = 42L; @@ -92,10 +97,14 @@ public abstract class BaseUnifier implements IUnifier, Serializable { } @Override public ITerm findTerm(ITerm term) { - return term.match(Terms.cases().var(var -> { + ITerm.Tag tag = term.termTag(); + if(tag == ITerm.Tag.ITermVar) { + ITermVar var = (ITermVar) term; final ITermVar rep = findRep(var); return terms().getOrDefault(rep, rep); - }).otherwise(t -> t)); + } else { + return term; + } } /////////////////////////////////////////// @@ -108,48 +117,74 @@ public abstract class BaseUnifier implements IUnifier, Serializable { private ITerm findTermRecursive(final ITerm term, final java.util.Set stack, final java.util.Map visited) { - return term.match(Terms.cases( - // @formatter:off - appl -> B.newAppl(appl.getOp(), findRecursiveTerms(appl.getArgs(), stack, visited), appl.getAttachments()), - list -> findListTermRecursive(list, stack, visited), - string -> string, - integer -> integer, - blob -> blob, - var -> findVarRecursive(var, stack, visited) - // @formatter:on - )); + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + return B.newAppl(appl.getOp(), findRecursiveTerms(appl.getArgs(), stack, visited), + appl.getAttachments()); + } + + case IConsTerm: + case INilTerm: { IListTerm list = (IListTerm) term; + return findListTermRecursive(list, stack, visited); + } + + case ITermVar: { ITermVar var = (ITermVar) term; + return findVarRecursive(var, stack, visited); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return term; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private IListTerm findListTermRecursive(IListTerm list, final java.util.Set stack, final java.util.Map visited) { Deque elements = new ArrayDeque<>(); while(list != null) { - list = list.match(ListTerms.cases( - // @formatter:off - cons -> { + switch(list.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) list; elements.push(cons); - return cons.getTail(); - }, - nil -> { - elements.push(nil); - return null; - }, - var -> { - elements.push(var); - return null; + list = cons.getTail(); + continue; + } + + case INilTerm: + case ITermVar: { + elements.push(list); + list = null; + continue; } - // @formatter:on - )); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } Ref instance = new Ref<>(); while(!elements.isEmpty()) { - instance.set(elements.pop().match(ListTerms.cases( - // @formatter:off - cons -> B.newCons(findTermRecursive(cons.getHead(), stack, visited), instance.get(), cons.getAttachments()), - nil -> nil, - var -> (IListTerm) findVarRecursive(var, stack, visited) - // @formatter:on - ))); + IListTerm element = elements.pop(); + switch(element.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) element; + instance.set(B.newCons(findTermRecursive(cons.getHead(), stack, visited), + instance.get(), cons.getAttachments())); + continue; + } + + case INilTerm: { INilTerm nil = (INilTerm) element; + instance.set(nil); + continue; + } + + case ITermVar: { ITermVar var = (ITermVar) element; + instance.set((IListTerm) findVarRecursive(var, stack, visited)); + continue; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } return instance.get(); } @@ -176,7 +211,7 @@ private ITerm findVarRecursive(final ITermVar var, final java.util.Set private Iterable findRecursiveTerms(final Iterable terms, final java.util.Set stack, final java.util.Map visited) { - List instances = Lists.newArrayList(); + List instances = new ArrayList<>(); for(ITerm term : terms) { instances.add(findTermRecursive(term, stack, visited)); } @@ -306,38 +341,53 @@ private void getVars(final ITermVar var, final LinkedList stack, final private TermSize size(final ITerm term, final java.util.Set stack, final java.util.Map visited) { - return term.match(Terms.cases( - // @formatter:off - appl -> TermSize.ONE.add(sizes(appl.getArgs(), stack, visited)), - list -> size(list, stack, visited), - string -> TermSize.ONE, - integer -> TermSize.ONE, - blob -> TermSize.ONE, - var -> size(var, stack, visited) - // @formatter:on - )); + switch(term.termTag()) { + case IApplTerm: { IApplTerm appl = (IApplTerm) term; + return TermSize.ONE.add(sizes(appl.getArgs(), stack, visited)); + } + + case IConsTerm: + case INilTerm: + case ITermVar: { + return size((IListTerm) term, stack, visited); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return TermSize.ONE; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private TermSize size(IListTerm list, final java.util.Set stack, final java.util.Map visited) { final Ref size = new Ref<>(TermSize.ZERO); while(list != null) { - list = list.match(ListTerms.cases( - // @formatter:off - cons -> { - size.set(size.get().add(TermSize.ONE).add(size(cons.getHead(), stack, visited))); - return cons.getTail(); - }, - nil -> { + switch(list.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) list; + size.set(size.get().add(TermSize.ONE) + .add(size(cons.getHead(), stack, visited))); + list = cons.getTail(); + continue; + } + + case INilTerm: { size.set(size.get().add(TermSize.ONE)); - return null; - }, - var -> { + list = null; + continue; + } + + case ITermVar: { ITermVar var = (ITermVar) list; size.set(size.get().add(size(var, stack, visited))); - return null; + list = null; + continue; } - // @formatter:on - )); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } return size.get(); } @@ -384,24 +434,48 @@ private TermSize sizes(final Iterable terms, final java.util.Set stack, - final java.util.Map visited, final int maxDepth, - final SpecializedTermFormatter specializedTermFormatter) { + final java.util.Map visited, final int maxDepth, + final SpecializedTermFormatter specializedTermFormatter) { if(maxDepth == 0) { return "…"; } - final PartialFunction1 tf = t -> specializedTermFormatter.formatSpecialized(term, this, st -> { - return toString(st, stack, visited, maxDepth - 1, specializedTermFormatter); - }); - // @formatter:off - return term.match(Terms.cases( - appl -> tf.apply(appl).orElseGet(() -> appl.getOp() + "(" + toStrings(appl.getArgs(), stack, visited, maxDepth - 1, specializedTermFormatter) + ")"), - list -> tf.apply(list).orElseGet(() -> toString(list, stack, visited, maxDepth, specializedTermFormatter)), - string -> tf.apply(string).orElseGet(() -> string.toString()), - integer -> tf.apply(integer).orElseGet(() -> integer.toString()), - blob -> tf.apply(blob).orElseGet(() -> blob.toString()), - var -> toString(var, stack, visited, maxDepth, specializedTermFormatter) - )); - // @formatter:on + if(term.termTag() == ITerm.Tag.ITermVar) { + ITermVar var = (ITermVar) term; + return toString(var, stack, visited, maxDepth, + specializedTermFormatter); + } + final Optional formatted = + specializedTermFormatter.formatSpecialized(term, this, + st -> toString(st, stack, visited, maxDepth - 1, specializedTermFormatter)); + switch(term.termTag()) { + case IApplTerm: { + IApplTerm appl = (IApplTerm) term; + return formatted.orElseGet( + () -> appl.getOp() + "(" + toStrings(appl.getArgs(), stack, visited, + maxDepth - 1, specializedTermFormatter) + ")"); + } + + case IConsTerm: + case INilTerm: { + IListTerm list = (IListTerm) term; + return formatted.orElseGet( + () -> toString(list, stack, visited, maxDepth, + specializedTermFormatter)); + } + + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return formatted.orElseGet(term::toString); + } + + case ITermVar: { + // impossible branch due to earlier if + return + break; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); } private String toString(IListTerm list, final java.util.Map stack, @@ -416,30 +490,36 @@ private String toString(IListTerm list, final java.util.Map st sb.append("["); while(list != null) { if(remaining == 0) { - if(list.match(ListTerms.cases().nil(nil -> false).otherwise(l -> true))) { + if(list.listTermTag() != IListTerm.Tag.INilTerm) { sb.append("|…"); } break; } - list = list.match(ListTerms.cases( - // @formatter:off - cons -> { + switch(list.listTermTag()) { + case IConsTerm: { IConsTerm cons = (IConsTerm) list; if(tail.getAndSet(true)) { sb.append(","); } - sb.append(toString(cons.getHead(), stack, visited, maxDepth - 1, specializedTermFormatter)); - return cons.getTail(); - }, - nil -> { - return null; - }, - var -> { + sb.append( + toString(cons.getHead(), stack, visited, maxDepth - 1, + specializedTermFormatter)); + list = cons.getTail(); + break; + } + + case INilTerm: { + list = null; + break; + } + + case ITermVar: { ITermVar var = (ITermVar) list; sb.append("|"); - sb.append(toString(var, stack, visited, maxDepth - 1, specializedTermFormatter)); - return null; + sb.append(toString(var, stack, visited, maxDepth - 1, + specializedTermFormatter)); + list = null; + break; } - // @formatter:on - )); + } remaining--; } sb.append("]"); diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/PersistentUnifier.java b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/PersistentUnifier.java index 6797d0611..cfbcb7481 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/PersistentUnifier.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/u/PersistentUnifier.java @@ -2,6 +2,7 @@ import java.io.Serializable; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Collection; import java.util.Deque; import java.util.Iterator; @@ -18,15 +19,18 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.LinkedListMultimap; import com.google.common.collect.ListMultimap; -import com.google.common.collect.Lists; import io.usethesource.capsule.Map; import io.usethesource.capsule.Set; +import mb.nabl2.terms.IApplTerm; +import mb.nabl2.terms.IBlobTerm; +import mb.nabl2.terms.IConsTerm; +import mb.nabl2.terms.IIntTerm; import mb.nabl2.terms.IListTerm; +import mb.nabl2.terms.INilTerm; +import mb.nabl2.terms.IStringTerm; import mb.nabl2.terms.ITerm; import mb.nabl2.terms.ITermVar; -import mb.nabl2.terms.ListTerms; -import mb.nabl2.terms.Terms; import mb.nabl2.terms.substitution.IRenaming; import mb.nabl2.terms.substitution.IReplacement; import mb.nabl2.terms.substitution.ISubstitution; @@ -145,7 +149,7 @@ private static class Unify extends PersistentUnifier.Transient { private final Predicate1 isRigid; private final Deque> worklist = new ArrayDeque<>(); - private final List result = Lists.newArrayList(); + private final List result = new ArrayList<>(); public Unify(PersistentUnifier.Immutable unifier, ITerm left, ITerm right, Predicate1 isRigid) { super(unifier); @@ -201,114 +205,202 @@ private void occursCheck(final PersistentUnifier.Immutable unifier) throws Occur } private boolean unifyTerms(final ITerm left, final ITerm right) throws RigidException { - // @formatter:off - return left.matchOrThrow(Terms.checkedCases( - applLeft -> right.matchOrThrow(Terms.checkedCases() - .appl(applRight -> { - return applLeft.getArity() == applRight.getArity() && - applLeft.getOp().equals(applRight.getOp()) && - unifys(applLeft.getArgs(), applRight.getArgs()); - }) - .var(varRight -> { - return unifyTerms(varRight, applLeft) ; - }) - .otherwise(t -> { - return false; - }) - ), - listLeft -> right.matchOrThrow(Terms.checkedCases() - .list(listRight -> { - return unifyLists(listLeft, listRight); - }) - .var(varRight -> { - return unifyTerms(varRight, listLeft); - }) - .otherwise(t -> { - return false; - }) - ), - stringLeft -> right.matchOrThrow(Terms.checkedCases() - .string(stringRight -> { - return stringLeft.getValue().equals(stringRight.getValue()); - }) - .var(varRight -> { - return unifyTerms(varRight, stringLeft); - }) - .otherwise(t -> { - return false; - }) - ), - integerLeft -> right.matchOrThrow(Terms.checkedCases() - .integer(integerRight -> { - return integerLeft.getValue() == integerRight.getValue(); - }) - .var(varRight -> { - return unifyTerms(varRight, integerLeft); - }) - .otherwise(t -> { - return false; - }) - ), - blobLeft -> right.matchOrThrow(Terms.checkedCases() - .blob(blobRight -> { - return blobLeft.getValue().equals(blobRight.getValue()); - }) - .var(varRight -> { - return unifyTerms(varRight, blobLeft); - }) - .otherwise(t -> { - return false; - }) - ), - varLeft -> right.matchOrThrow(Terms.checkedCases() - .var(varRight -> { - return unifyVars(varLeft, varRight); - }) - .otherwise(termRight -> { - return unifyVarTerm(varLeft, termRight); - }) - ) - )); - // @formatter:on - } - - private boolean unifyLists(final IListTerm left, final IListTerm right) throws RigidException { - // @formatter:off - return left.matchOrThrow(ListTerms.checkedCases( - consLeft -> right.matchOrThrow(ListTerms.checkedCases() - .cons(consRight -> { - worklist.push(Tuple2.of(consLeft.getHead(), consRight.getHead())); - worklist.push(Tuple2.of(consLeft.getTail(), consRight.getTail())); - return true; - }) - .var(varRight -> { - return unifyLists(varRight, consLeft); - }) - .otherwise(l -> { - return false; - }) - ), - nilLeft -> right.matchOrThrow(ListTerms.checkedCases() - .nil(nilRight -> { - return true; - }) - .var(varRight -> { - return unifyVarTerm(varRight, nilLeft) ; - }) - .otherwise(l -> { - return false; - }) - ), - varLeft -> right.matchOrThrow(ListTerms.checkedCases() - .var(varRight -> { - return unifyVars(varLeft, varRight); - }) - .otherwise(termRight -> { - return unifyVarTerm(varLeft, termRight); - }) - ) - )); - // @formatter:on + switch(left.termTag()) { + case IApplTerm: { IApplTerm applTerm = (IApplTerm) left; + switch(right.termTag()) { + case IApplTerm: { IApplTerm applRight = (IApplTerm) right; + return applTerm.getArity() == applRight.getArity() && + applTerm.getOp().equals(applRight.getOp()) && + unifys(applTerm.getArgs(), applRight.getArgs()); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyTerms(varRight, applTerm); + } + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + case IConsTerm: + case INilTerm: { IListTerm list = (IListTerm) left; + switch(right.termTag()) { + case IConsTerm: + case INilTerm: { IListTerm listRight = (IListTerm) right; + return unifyLists(list, listRight); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyTerms(varRight, list); + } + case IApplTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + case IStringTerm: { IStringTerm string = (IStringTerm) left; + switch(right.termTag()) { + case IStringTerm: { IStringTerm stringRight = (IStringTerm) right; + return string.getValue().equals(stringRight.getValue()); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyTerms(varRight, string); + } + case IApplTerm: + case IConsTerm: + case INilTerm: + case IIntTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + case IIntTerm: { IIntTerm integer = (IIntTerm) left; + switch(right.termTag()) { + case IIntTerm: { IIntTerm integerRight = (IIntTerm) right; + return integer.getValue() == integerRight.getValue(); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyTerms(varRight, integer); + } + case IApplTerm: + case IConsTerm: + case INilTerm: + case IStringTerm: + case IBlobTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + case IBlobTerm: { IBlobTerm blob = (IBlobTerm) left; + switch(right.termTag()) { + case IBlobTerm: { IBlobTerm blobRight = (IBlobTerm) right; + return blob.getValue().equals(blobRight.getValue()); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyTerms(varRight, blob); + } + case IApplTerm: + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + case ITermVar: { ITermVar var = (ITermVar) left; + switch(right.termTag()) { + case IApplTerm: + case IConsTerm: + case INilTerm: + case IStringTerm: + case IIntTerm: + case IBlobTerm: { + return unifyVarTerm(var, right); + } + case ITermVar: { ITermVar varRight = (ITermVar) right; + return unifyVars(var, varRight); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for ITerm subclass/tag"); + } + + private boolean unifyLists(final IListTerm left, final IListTerm right) + throws RigidException { + switch(left.listTermTag()) { + case IConsTerm: { + IConsTerm cons = (IConsTerm) left; + switch(right.listTermTag()) { + case IConsTerm: { + IConsTerm consRight = (IConsTerm) right; + worklist.push(Tuple2.of(cons.getHead(), consRight.getHead())); + worklist.push(Tuple2.of(cons.getTail(), consRight.getTail())); + return true; + } + + case INilTerm: { + INilTerm varRight = (INilTerm) right; + return unifyLists(varRight, cons); + } + + case ITermVar: { + return false; + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); + } + + case INilTerm: { + INilTerm nil = (INilTerm) left; + switch(right.listTermTag()) { + case IConsTerm: { + return false; + } + + case INilTerm: { + return true; + } + + case ITermVar: { + ITermVar varRight = (ITermVar) right; + return unifyVarTerm(varRight, nil); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); + } + + case ITermVar: { + ITermVar var = (ITermVar) left; + switch(right.listTermTag()) { + case IConsTerm: { + IConsTerm termRight = (IConsTerm) right; + return unifyVarTerm(var, termRight); + } + + case INilTerm: { + INilTerm termRight = (INilTerm) right; + return unifyVarTerm(var, termRight); + } + + case ITermVar: { + ITermVar varRight = (ITermVar) right; + return unifyVars(var, varRight); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); + } + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IListTerm subclass/tag"); } private boolean unifyVarTerm(final ITermVar var, final ITerm term) throws RigidException { @@ -831,10 +923,9 @@ public PersistentUnifier.Immutable freeze() { && domainSetCache.isEmpty() && rangeSetCache.isEmpty() && varSetCache.isEmpty()) { return Immutable.of(finite); } else { - final PersistentUnifier.Immutable unifier = new PersistentUnifier.Immutable(finite, reps.freeze(), + return new Immutable(finite, reps.freeze(), ranks.freeze(), terms.freeze(), repAndTermVarsCache.freeze(), domainSetCache.freeze(), rangeSetCache.freeze(), varSetCache.freeze()); - return unifier; } } diff --git a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/ud/PersistentUniDisunifier.java b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/ud/PersistentUniDisunifier.java index 6cbb926e5..c0eea1d4a 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/terms/unification/ud/PersistentUniDisunifier.java +++ b/nabl2.terms/src/main/java/mb/nabl2/terms/unification/ud/PersistentUniDisunifier.java @@ -114,9 +114,7 @@ public static class Immutable extends PersistentUniDisunifier implements IUniDis if((r = unifier.unify(left, right, isRigid).orElse(null)) == null) { return Optional.empty(); } - return normalizeDiseqs(r.unifier(), disequalities).map(ud -> { - return new ImmutableResult<>(r.result(), ud); - }); + return normalizeDiseqs(r.unifier(), disequalities).map(ud -> new ImmutableResult<>(r.result(), ud)); } @Override public Optional> unify( @@ -126,9 +124,7 @@ public static class Immutable extends PersistentUniDisunifier implements IUniDis if((r = unifier.unify(equalities, isRigid).orElse(null)) == null) { return Optional.empty(); } - return normalizeDiseqs(r.unifier(), disequalities).map(ud -> { - return new ImmutableResult<>(r.result(), ud); - }); + return normalizeDiseqs(r.unifier(), disequalities).map(ud -> new ImmutableResult<>(r.result(), ud)); } @Override public Optional> unify(IUnifier other, @@ -137,9 +133,7 @@ public static class Immutable extends PersistentUniDisunifier implements IUniDis if((r = unifier.unify(other, isRigid).orElse(null)) == null) { return Optional.empty(); } - return normalizeDiseqs(r.unifier(), disequalities).map(ud -> { - return new ImmutableResult<>(r.result(), ud); - }); + return normalizeDiseqs(r.unifier(), disequalities).map(ud -> new ImmutableResult<>(r.result(), ud)); } @Override public Optional> uniDisunify(IUniDisunifier other, @@ -156,9 +150,7 @@ public static class Immutable extends PersistentUniDisunifier implements IUniDis unifier.addRangeVar(var, 1); } } - return normalizeDiseqs(unifier.freeze(), disequalities.__insertAll(other.disequalities())).map(ud -> { - return new ImmutableResult<>(r.result(), ud); - }); + return normalizeDiseqs(unifier.freeze(), disequalities.__insertAll(other.disequalities())).map(ud -> new ImmutableResult<>(r.result(), ud)); } /////////////////////////////////////////// diff --git a/nabl2.terms/src/main/java/mb/nabl2/util/collections/IndexedBagMultimap.java b/nabl2.terms/src/main/java/mb/nabl2/util/collections/IndexedBagMultimap.java index a9643964a..6909c32b6 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/util/collections/IndexedBagMultimap.java +++ b/nabl2.terms/src/main/java/mb/nabl2/util/collections/IndexedBagMultimap.java @@ -1,5 +1,6 @@ package mb.nabl2.util.collections; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Map; @@ -105,7 +106,7 @@ public Collection reindex(I index, Function1 reindexAll(Function1> normalize) { - return Lists.newArrayList(entries.keySet()).stream().flatMap(i -> reindex(i, normalize).stream()) + return new ArrayList<>(entries.keySet()).stream().flatMap(i -> reindex(i, normalize).stream()) .collect(ImmutableList.toImmutableList()); } diff --git a/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/counting/CountingTcRelation.java b/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/counting/CountingTcRelation.java index b32b7f534..bc536fd9e 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/counting/CountingTcRelation.java +++ b/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/counting/CountingTcRelation.java @@ -115,14 +115,10 @@ public boolean updateTuple(V source, V target, boolean isInsertion) { public void deleteTupleEnd(V deleted) { Set sourcesToDelete = CollectionsFactory.createSet(); Set targetsToDelete = CollectionsFactory.createSet(); - - for (V target : tuplesForward.lookupOrEmpty(deleted).distinctValues()) { - targetsToDelete.add(target); - } + + targetsToDelete.addAll(tuplesForward.lookupOrEmpty(deleted).distinctValues()); if (tuplesBackward != null) { - for (V source : tuplesBackward.lookupOrEmpty(deleted).distinctValues()) { - sourcesToDelete.add(source); - } + sourcesToDelete.addAll(tuplesBackward.lookupOrEmpty(deleted).distinctValues()); } else { for (V sourceCandidate : tuplesForward.distinctKeys()) { if (tuplesForward.lookupOrEmpty(sourceCandidate).containsNonZero(deleted)) diff --git a/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/misc/CollectionsFactory.java b/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/misc/CollectionsFactory.java index d211b420b..af4bfba13 100644 --- a/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/misc/CollectionsFactory.java +++ b/nabl2.terms/src/main/java/mb/nabl2/util/graph/alg/misc/CollectionsFactory.java @@ -8,6 +8,7 @@ *******************************************************************************/ package mb.nabl2.util.graph.alg.misc; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @@ -57,7 +58,7 @@ public static Set createSet(Collection initial) { * @since 1.7 */ public static List createObserverList() { - return Lists.newArrayList(); + return new ArrayList<>(); } /** diff --git a/statix.generator/src/main/java/mb/statix/generator/strategy/SearchStrategies.java b/statix.generator/src/main/java/mb/statix/generator/strategy/SearchStrategies.java index d42b6d222..49a68de8c 100644 --- a/statix.generator/src/main/java/mb/statix/generator/strategy/SearchStrategies.java +++ b/statix.generator/src/main/java/mb/statix/generator/strategy/SearchStrategies.java @@ -1,5 +1,6 @@ package mb.statix.generator.strategy; +import java.util.EnumSet; import java.util.Map; import org.metaborg.util.functions.Action1; @@ -201,13 +202,15 @@ public static Require requi public static SearchStrategy mapPred(String pattern, Function1 f) { final mb.statix.generator.predicate.Match match = new mb.statix.generator.predicate.Match(pattern); - return map(Constraints.bottomup(Constraints.cases().user(c -> { - if(match.test(c)) { - return f.apply(c); - } else { - return c; + return map(Constraints.bottomup(constraint -> { + if(constraint.constraintTag() == IConstraint.Tag.CUser) { + CUser c = (CUser) constraint; + if(match.test(c)) { + return f.apply(c); + } } - }).otherwise(c -> c), false)); + return constraint; + }, false)); } public static SearchStrategy addAuxPred(String pattern, Function1 f) { @@ -216,12 +219,14 @@ public static SearchStrategy addAuxPred(String patter public static SearchStrategy dropPred(String pattern) { final mb.statix.generator.predicate.Match match = new mb.statix.generator.predicate.Match(pattern); - return filter(Constraints.cases().user(c -> !match.test(c)).otherwise(c -> true)::apply); + return filter(constraint -> constraint.constraintTag() != IConstraint.Tag.CUser || !match.test( + (CUser) constraint)); } + public static final EnumSet astTags = EnumSet.of(IConstraint.Tag.CAstId, IConstraint.Tag.CAstProperty); + public static SearchStrategy dropAst() { - return filter( - Constraints.cases().termId(c -> false).termProperty(c -> false).otherwise(c -> true)::apply); + return filter(constraint -> !astTags.contains(constraint.constraintTag())); } diff --git a/statix.solver/src/main/java/mb/statix/concurrent/StatixSolver.java b/statix.solver/src/main/java/mb/statix/concurrent/StatixSolver.java index 446e40bef..f7c508ff8 100644 --- a/statix.solver/src/main/java/mb/statix/concurrent/StatixSolver.java +++ b/statix.solver/src/main/java/mb/statix/concurrent/StatixSolver.java @@ -477,9 +477,8 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } // solve - return constraint.matchOrThrow(new IConstraint.CheckedCases() { - - @Override public Boolean caseArith(CArith c) throws InterruptedException { + switch(constraint.constraintTag()) { + case CArith: { CArith c = (CArith) constraint; final IUniDisunifier unifier = state.unifier(); final Optional term1 = c.expr1().isTerm(); final Optional term2 = c.expr2().isTerm(); @@ -509,11 +508,11 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseConj(CConj c) throws InterruptedException { + case CConj: { CConj c = (CConj) constraint; return success(c, state, NO_UPDATED_VARS, disjoin(c), NO_NEW_CRITICAL_EDGES, NO_EXISTENTIALS, fuel); } - @Override public Boolean caseEqual(CEqual c) throws InterruptedException { + case CEqual: { CEqual c = (CEqual) constraint; final ITerm term1 = c.term1(); final ITerm term2 = c.term2(); IUniDisunifier.Immutable unifier = state.unifier(); @@ -544,7 +543,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseExists(CExists c) throws InterruptedException { + case CExists: { CExists c = (CExists) constraint; final Renaming.Builder _existentials = Renaming.builder(); IState.Immutable newState = state; for(ITermVar var : c.vars()) { @@ -567,11 +566,11 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException existentials.asMap(), fuel); } - @Override public Boolean caseFalse(CFalse c) throws InterruptedException { + case CFalse: { CFalse c = (CFalse) constraint; return fail(c); } - @Override public Boolean caseInequal(CInequal c) throws InterruptedException { + case CInequal: { CInequal c = (CInequal) constraint; final ITerm term1 = c.term1(); final ITerm term2 = c.term2(); final IUniDisunifier.Immutable unifier = state.unifier(); @@ -598,7 +597,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseNew(CNew c) throws InterruptedException { + case CNew: { CNew c = (CNew) constraint; final ITerm scopeTerm = c.scopeTerm(); final ITerm datumTerm = c.datumTerm(); final String name = M.var(ITermVar::getName).match(scopeTerm).orElse("s"); @@ -611,7 +610,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException fuel); } - @Override public Boolean caseResolveQuery(IResolveQuery c) throws InterruptedException { + case IResolveQuery: { IResolveQuery c = (IResolveQuery) constraint; final QueryFilter filter = c.filter(); final QueryMin min = c.min(); final ITerm scopeTerm = c.scopeTerm(); @@ -644,31 +643,29 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException final DataLeq dataEquivInternal = LOCAL_INFERENCE ? new ConstraintDataEquivInternal(dataLeqRule) : null; - final IFuture>> future; + IFuture>> future = null; if((flags & Solver.FORCE_INTERP_QUERIES) == 0) { - // @formatter:off - future = c.match(new IResolveQuery.Cases>>>() { - - @Override public IFuture>> caseResolveQuery(CResolveQuery q) { + switch(c.resolveQueryTag()) { + case CResolveQuery: { CResolveQuery q = (CResolveQuery) c; final LabelOrder labelOrder = new RelationLabelOrder<>(min.getLabelOrder()); - return scopeGraph.query(scope, labelWF, labelOrder, dataWF, dataEquiv, - dataWFInternal, dataEquivInternal); + future = scopeGraph.query(scope, labelWF, labelOrder, dataWF, dataEquiv, + dataWFInternal, dataEquivInternal); + break; } - @Override public IFuture>> caseCompiledQuery(CCompiledQuery q) { - return scopeGraph.query(scope, q.stateMachine(), labelWF, dataWF, dataEquiv, - dataWFInternal, dataEquivInternal); + case CCompiledQuery: { CCompiledQuery q = (CCompiledQuery) c; + future = scopeGraph.query(scope, q.stateMachine(), labelWF, dataWF, + dataEquiv, dataWFInternal, dataEquivInternal); + break; } - - }); - // @formatter:on + } } else { final LabelOrder labelOrder = new RelationLabelOrder<>(min.getLabelOrder()); future = scopeGraph.query(scope, labelWF, labelOrder, dataWF, dataEquiv, dataWFInternal, dataEquivInternal); } - final K>> k = (paths, ex, fuel) -> { + final K>> k = (paths, ex, fuel2) -> { if(ex != null) { // pattern matching for the brave and stupid try { @@ -696,13 +693,13 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException .collect(ImmutableList.toImmutableList()); final IConstraint C = new CEqual(resultTerm, B.newList(pathTerms), c); return success(c, state, NO_UPDATED_VARS, ImmutableList.of(C), NO_NEW_CRITICAL_EDGES, - NO_EXISTENTIALS, fuel); + NO_EXISTENTIALS, fuel2); } }; return future(c, future, k); } - @Override public Boolean caseTellEdge(CTellEdge c) throws InterruptedException { + case CTellEdge: { CTellEdge c = (CTellEdge) constraint; final ITerm sourceTerm = c.sourceTerm(); final ITerm label = c.label(); final ITerm targetTerm = c.targetTerm(); @@ -724,7 +721,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException fuel); } - @Override public Boolean caseTermId(CAstId c) throws InterruptedException { + case CAstId: { CAstId c = (CAstId) constraint; final ITerm term = c.astTerm(); final ITerm idTerm = c.idTerm(); @@ -752,7 +749,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseTermProperty(CAstProperty c) throws InterruptedException { + case CAstProperty: { CAstProperty c = (CAstProperty) constraint; final ITerm idTerm = c.idTerm(); final ITerm prop = c.property(); final ITerm value = c.value(); @@ -801,19 +798,19 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseTrue(CTrue c) throws InterruptedException { + case CTrue: { CTrue c = (CTrue) constraint; return success(c, state, NO_UPDATED_VARS, NO_NEW_CONSTRAINTS, NO_NEW_CRITICAL_EDGES, NO_EXISTENTIALS, fuel); } - @Override public Boolean caseTry(CTry c) throws InterruptedException { + case CTry: { CTry c = (CTry) constraint; final IDebugContext subDebug = debug.subContext(); final ITypeCheckerContext subContext = scopeGraph.subContext("try"); final IState.Immutable subState = state.subState().withResource(subContext.id()); final StatixSolver subSolver = new StatixSolver(c.constraint(), spec, subState, completeness, subDebug, progress, cancel, subContext, RETURN_ON_FIRST_ERROR); final IFuture subResult = subSolver.entail(); - final K k = (r, ex, fuel) -> { + final K k = (r, ex, fuel2) -> { if(ex != null) { debug.error("try {} failed", ex, c.toString(state.unifier()::toString)); return fail(c); @@ -826,7 +823,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException debug.debug("constraint {} entailed", c.toString(state.unifier()::toString)); } return success(c, state, NO_UPDATED_VARS, NO_NEW_CONSTRAINTS, NO_NEW_CRITICAL_EDGES, - NO_EXISTENTIALS, fuel); + NO_EXISTENTIALS, fuel2); } else { if(debug.isEnabled(Level.Debug)) { debug.debug("constraint {} not entailed", c.toString(state.unifier()::toString)); @@ -841,7 +838,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException return future(c, subResult, k); } - @Override public Boolean caseUser(CUser c) throws InterruptedException { + case CUser: { CUser c = (CUser) constraint; final String name = c.name(); final List args = c.args(); @@ -871,8 +868,10 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException applyResult.criticalEdges(), NO_EXISTENTIALS, fuel); } - }); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); } /////////////////////////////////////////////////////////////////////////// diff --git a/statix.solver/src/main/java/mb/statix/concurrent/util/Patching.java b/statix.solver/src/main/java/mb/statix/concurrent/util/Patching.java index ff4f1384f..e478119e4 100644 --- a/statix.solver/src/main/java/mb/statix/concurrent/util/Patching.java +++ b/statix.solver/src/main/java/mb/statix/concurrent/util/Patching.java @@ -38,7 +38,6 @@ import mb.statix.constraints.messages.IMessage; import mb.statix.scopegraph.Scope; import mb.statix.solver.IConstraint; -import mb.statix.solver.IConstraint.Cases; import mb.statix.solver.completeness.ICompleteness; import mb.statix.solver.query.QueryFilter; import mb.statix.solver.query.QueryMin; @@ -62,37 +61,34 @@ public static Set.Immutable ruleScopes(Rule rule) { } public static Set.Immutable constraintScopes(IConstraint constraint) { - return constraint.match(new Cases>() { - - @Override public Immutable caseArith(CArith c) { + switch(constraint.constraintTag()) { + case CArith: + case CTrue: + case CFalse: { return CapsuleUtil.immutableSet(); } - - @Override public Immutable caseConj(CConj c) { + case CConj: { + CConj c = (CConj) constraint; return constraintScopes(c.left()).__insertAll(constraintScopes(c.right())); } - - @Override public Immutable caseEqual(CEqual c) { + case CEqual: { + CEqual c = (CEqual) constraint; return termScopes(c.term1()).__insertAll(termScopes(c.term2())); } - - @Override public Immutable caseExists(CExists c) { + case CExists: { + CExists c = (CExists) constraint; return constraintScopes(c.constraint()); } - - @Override public Immutable caseFalse(CFalse c) { - return CapsuleUtil.immutableSet(); - } - - @Override public Immutable caseInequal(CInequal c) { + case CInequal: { + CInequal c = (CInequal) constraint; return termScopes(c.term1()).__insertAll(termScopes(c.term2())); } - - @Override public Immutable caseNew(CNew c) { + case CNew: { + CNew c = (CNew) constraint; return termScopes(c.datumTerm()); } - - @Override public Immutable caseResolveQuery(IResolveQuery c) { + case IResolveQuery: { + IResolveQuery c = (IResolveQuery) constraint; final Set.Immutable scopeTermScopes = termScopes(c.scopeTerm()); final Set.Immutable resultTermScopes = termScopes(c.resultTerm()); @@ -100,34 +96,32 @@ public static Set.Immutable constraintScopes(IConstraint constraint) { final Set.Immutable dataEquivScopes = ruleScopes(c.min().getDataEquiv()); return scopeTermScopes.__insertAll(resultTermScopes).__insertAll(dataWfScopes) - .__insertAll(dataEquivScopes); + .__insertAll(dataEquivScopes); } - - @Override public Immutable caseTellEdge(CTellEdge c) { + case CTellEdge: { + CTellEdge c = (CTellEdge) constraint; return termScopes(c.sourceTerm()).__insertAll(termScopes(c.targetTerm())); } - - @Override public Immutable caseTermId(CAstId c) { + case CAstId: { + CAstId c = (CAstId) constraint; return termScopes(c.astTerm()).__insertAll(termScopes(c.idTerm())); } - - @Override public Immutable caseTermProperty(CAstProperty c) { + case CAstProperty: { + CAstProperty c = (CAstProperty) constraint; return termScopes(c.idTerm()).__insertAll(termScopes(c.value())); } - - @Override public Immutable caseTrue(CTrue c) { - return CapsuleUtil.immutableSet(); - } - - @Override public Immutable caseTry(CTry c) { + case CTry: { + CTry c = (CTry) constraint; return constraintScopes(c.constraint()); } - - @Override public Immutable caseUser(CUser c) { + case CUser: { + CUser c = (CUser) constraint; return c.args().stream().map(Patching::termScopes).flatMap(Set.Immutable::stream) - .collect(CapsuleCollectors.toSet()); + .collect(CapsuleCollectors.toSet()); } - }); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); } public static Set.Immutable termScopes(ITerm term) { @@ -153,36 +147,34 @@ public static Rule patch(Rule rule, IPatchCollection patches) { } public static IConstraint patch(IConstraint constraint, IPatchCollection patches) { - return constraint.match(new Cases() { + switch(constraint.constraintTag()) { - @Override public IConstraint caseArith(CArith c) { - return c; + case CArith: + case CFalse: + case CTrue: { + return constraint; } - @Override public IConstraint caseConj(CConj c) { + case CConj: { CConj c = (CConj) constraint; final IConstraint newLeft = patch(c.left(), patches); final IConstraint newRight = patch(c.right(), patches); return new CConj(newLeft, newRight, c.cause().orElse(null)); } - @Override public IConstraint caseEqual(CEqual c) { + case CEqual: { CEqual c = (CEqual) constraint; final ITerm newTerm1 = patch(c.term1(), patches); final ITerm newTerm2 = patch(c.term2(), patches); return new CEqual(newTerm1, newTerm2, c.cause().orElse(null), c.message().orElse(null)); } - @Override public IConstraint caseExists(CExists c) { + case CExists: { CExists c = (CExists) constraint; // TODO: preserve free vars? return c.withConstraint(patch(c.constraint(), patches)); } - @Override public IConstraint caseFalse(CFalse c) { - return c; - } - - @Override public IConstraint caseInequal(CInequal c) { + case CInequal: { CInequal c = (CInequal) constraint; final ITerm newTerm1 = patch(c.term1(), patches); final ITerm newTerm2 = patch(c.term2(), patches); @@ -192,14 +184,14 @@ public static IConstraint patch(IConstraint constraint, IPatchCollection return new CInequal(c.universals(), newTerm1, newTerm2, cause, message); } - @Override public IConstraint caseNew(CNew c) { + case CNew: { CNew c = (CNew) constraint; final ITerm newScopeTerm = patch(c.scopeTerm(), patches); final ITerm newDatumTerm = patch(c.datumTerm(), patches); return new CNew(newScopeTerm, newDatumTerm, c.cause().orElse(null), c.ownCriticalEdges().orElse(null)); } - @Override public IConstraint caseResolveQuery(IResolveQuery c) { + case IResolveQuery: { IResolveQuery c = (IResolveQuery) constraint; final ITerm newScopeTerm = patch(c.scopeTerm(), patches); final ITerm newResultTerm = patch(c.resultTerm(), patches); @@ -212,18 +204,20 @@ public static IConstraint patch(IConstraint constraint, IPatchCollection final @Nullable IConstraint cause = c.cause().orElse(null); final @Nullable IMessage message = c.message().orElse(null); - return c.match(new IResolveQuery.Cases() { - @Override public IResolveQuery caseResolveQuery(CResolveQuery q) { + switch(c.resolveQueryTag()) { + case CResolveQuery: { return new CResolveQuery(newFilter, newMin, newScopeTerm, newResultTerm, cause, message); } - @Override public IResolveQuery caseCompiledQuery(CCompiledQuery q) { + case CCompiledQuery: { CCompiledQuery q = (CCompiledQuery) c; return new CCompiledQuery(newFilter, newMin, newScopeTerm, newResultTerm, cause, message, q.stateMachine()); } - }); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IResolveQuery subclass/tag"); } - @Override public IConstraint caseTellEdge(CTellEdge c) { + case CTellEdge: { CTellEdge c = (CTellEdge) constraint; // TODO: patch critical edges? final ITerm newSourceTerm = patch(c.sourceTerm(), patches); final ITerm newTargetTerm = patch(c.targetTerm(), patches); @@ -234,36 +228,34 @@ public static IConstraint patch(IConstraint constraint, IPatchCollection return new CTellEdge(newSourceTerm, c.label(), newTargetTerm, cause, bodyCriticalEdges); } - @Override public IConstraint caseTermId(CAstId c) { + case CAstId: { CAstId c = (CAstId) constraint; final ITerm newAstTerm = patch(c.astTerm(), patches); final ITerm newIdTerm = patch(c.idTerm(), patches); return new CAstId(newAstTerm, newIdTerm, c.cause().orElse(null)); } - @Override public IConstraint caseTermProperty(CAstProperty c) { + case CAstProperty: { CAstProperty c = (CAstProperty) constraint; final ITerm newIdTerm = patch(c.idTerm(), patches); final ITerm newValue = patch(c.value(), patches); return new CAstProperty(newIdTerm, c.property(), c.op(), newValue, c.cause().orElse(null)); } - @Override public IConstraint caseTrue(CTrue c) { - return c; - } - - @Override public IConstraint caseTry(CTry c) { + case CTry: { CTry c = (CTry) constraint; final IConstraint newConstraint = patch(c.constraint(), patches); return new CTry(newConstraint, c.cause().orElse(null), c.message().orElse(null)); } - @Override public IConstraint caseUser(CUser c) { + case CUser: { CUser c = (CUser) constraint; final ImmutableList newArgs = c.args().stream().map(arg -> patch(arg, patches)).collect(ImmutableList.toImmutableList()); // TODO Patch ownCriticalEdges? return new CUser(c.name(), newArgs, c.cause().orElse(null), c.message().orElse(null), c.ownCriticalEdges().orElse(null)); } - }); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); } public static Pattern patch(Pattern pattern, IPatchCollection patches) { diff --git a/statix.solver/src/main/java/mb/statix/constraints/CArith.java b/statix.solver/src/main/java/mb/statix/constraints/CArith.java index 858b4d8d9..f15d37638 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CArith.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CArith.java @@ -75,14 +75,6 @@ public ArithExpr expr2() { return new CArith(expr1, op, expr2, cause, message); } - @Override public R match(Cases cases) { - return cases.caseArith(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseArith(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( expr1.getVars(), @@ -131,6 +123,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CArith; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CAstId.java b/statix.solver/src/main/java/mb/statix/constraints/CAstId.java index 1870fc651..af52fcb6d 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CAstId.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CAstId.java @@ -51,14 +51,6 @@ public ITerm idTerm() { return new CAstId(term, idTerm, cause); } - @Override public R match(Cases cases) { - return cases.caseTermId(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseTermId(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( term.getVars(), @@ -103,6 +95,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CAstId; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CAstProperty.java b/statix.solver/src/main/java/mb/statix/constraints/CAstProperty.java index d68e33646..f3bff2986 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CAstProperty.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CAstProperty.java @@ -76,14 +76,6 @@ public ITerm value() { return new CAstProperty(idTerm, property, op, value, cause); } - @Override public R match(Cases cases) { - return cases.caseTermProperty(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseTermProperty(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( idTerm.getVars(), @@ -129,6 +121,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CAstProperty; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CCompiledQuery.java b/statix.solver/src/main/java/mb/statix/constraints/CCompiledQuery.java index 87237aed3..545cc2bd9 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CCompiledQuery.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CCompiledQuery.java @@ -42,15 +42,15 @@ public StateMachine stateMachine() { return stateMachine; } - @Override public R match(Cases cases) { - return cases.caseCompiledQuery(this); - } - @Override public R matchInResolution(ResolutionFunction1 onResolveQuery, ResolutionFunction1 onCompiledQuery) throws ResolutionException, InterruptedException { return onCompiledQuery.apply(this); } + @Override public Tag resolveQueryTag() { + return Tag.CCompiledQuery; + } + @Override public CCompiledQuery withCause(IConstraint cause) { return new CCompiledQuery(filter, min, scopeTerm, resultTerm, cause, message, stateMachine); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CConj.java b/statix.solver/src/main/java/mb/statix/constraints/CConj.java index ec19b275b..a81b483f4 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CConj.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CConj.java @@ -51,14 +51,6 @@ public IConstraint right() { return new CConj(left, right, cause); } - @Override public R match(Cases cases) { - return cases.caseConj(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseConj(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( left.getVars(), @@ -101,6 +93,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CConj; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CEqual.java b/statix.solver/src/main/java/mb/statix/constraints/CEqual.java index ce127cf31..205be56fb 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CEqual.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CEqual.java @@ -70,14 +70,6 @@ public ITerm term2() { return new CEqual(term1, term2, cause, message); } - @Override public R match(Cases cases) { - return cases.caseEqual(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseEqual(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( term1.getVars(), @@ -123,6 +115,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CEqual; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CExists.java b/statix.solver/src/main/java/mb/statix/constraints/CExists.java index 791f7152e..6a577e5ce 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CExists.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CExists.java @@ -81,15 +81,6 @@ public CExists withConstraint(IConstraint constraint) { return new CExists(vars, constraint, cause, criticalEdges, freeVars); } - - @Override public R match(Cases cases) { - return cases.caseExists(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseExists(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( vars, @@ -192,6 +183,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CExists; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CFalse.java b/statix.solver/src/main/java/mb/statix/constraints/CFalse.java index 8f1a6d4cc..fff369383 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CFalse.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CFalse.java @@ -53,14 +53,6 @@ public CFalse(@Nullable IConstraint cause, @Nullable IMessage message) { return new CFalse(cause, message); } - @Override public R match(Cases cases) { - return cases.caseFalse(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseFalse(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.of(); } @@ -97,6 +89,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return "false"; } + @Override public Tag constraintTag() { + return Tag.CFalse; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CInequal.java b/statix.solver/src/main/java/mb/statix/constraints/CInequal.java index 61a1f09ae..c43739ca5 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CInequal.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CInequal.java @@ -75,14 +75,6 @@ public ITerm term2() { return new CInequal(universals, term1, term2, cause, message); } - @Override public R match(Cases cases) { - return cases.caseInequal(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseInequal(this); - } - @Override public Set.Immutable getVars() { final Set.Transient vars = Set.Transient.of(); vars.__insertAll(universals); @@ -147,6 +139,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CInequal; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CNew.java b/statix.solver/src/main/java/mb/statix/constraints/CNew.java index ddd82a5c7..23e7fffce 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CNew.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CNew.java @@ -47,14 +47,6 @@ public ITerm datumTerm() { return datumTerm; } - @Override public R match(Cases cases) { - return cases.caseNew(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseNew(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( scopeTerm.getVars(), @@ -117,6 +109,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CNew; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CResolveQuery.java b/statix.solver/src/main/java/mb/statix/constraints/CResolveQuery.java index fd49f870f..51e13c335 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CResolveQuery.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CResolveQuery.java @@ -33,15 +33,15 @@ public CResolveQuery(QueryFilter filter, QueryMin min, ITerm scopeTerm, ITerm re super(filter, min, scopeTerm, resultTerm, cause, message); } - @Override public R match(Cases cases) { - return cases.caseResolveQuery(this); - } - @Override public R matchInResolution(ResolutionFunction1 onResolveQuery, ResolutionFunction1 onCompiledQuery) throws ResolutionException, InterruptedException { return onResolveQuery.apply(this); } + @Override public Tag resolveQueryTag() { + return Tag.CResolveQuery; + } + @Override public CResolveQuery withCause(@Nullable IConstraint cause) { return new CResolveQuery(filter, min, scopeTerm, resultTerm, cause, message); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CTellEdge.java b/statix.solver/src/main/java/mb/statix/constraints/CTellEdge.java index e35ef259e..7bde91af8 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CTellEdge.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CTellEdge.java @@ -69,14 +69,6 @@ public ITerm targetTerm() { return new CTellEdge(sourceTerm, label, targetTerm, cause, criticalEdges); } - @Override public R match(Cases cases) { - return cases.caseTellEdge(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseTellEdge(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.union( sourceTerm.getVars(), @@ -124,6 +116,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CTellEdge; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CTrue.java b/statix.solver/src/main/java/mb/statix/constraints/CTrue.java index 57acd60a9..4fe21b351 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CTrue.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CTrue.java @@ -38,14 +38,6 @@ public CTrue(@Nullable IConstraint cause) { return new CTrue(cause); } - @Override public R match(Cases cases) { - return cases.caseTrue(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseTrue(this); - } - @Override public Set.Immutable getVars() { return Set.Immutable.of(); } @@ -79,6 +71,10 @@ private void doVisitFreeVars(@SuppressWarnings("unused") Action1 onFre return "true"; } + @Override public Tag constraintTag() { + return Tag.CTrue; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CTry.java b/statix.solver/src/main/java/mb/statix/constraints/CTry.java index b837a3a1e..c5af413a2 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CTry.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CTry.java @@ -60,14 +60,6 @@ public IConstraint constraint() { return new CTry(constraint, cause, message); } - @Override public R match(Cases cases) { - return cases.caseTry(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseTry(this); - } - @Override public Set.Immutable getVars() { return constraint.getVars(); } @@ -109,6 +101,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CTry; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/CUser.java b/statix.solver/src/main/java/mb/statix/constraints/CUser.java index 7b9d51e3d..694366713 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/CUser.java +++ b/statix.solver/src/main/java/mb/statix/constraints/CUser.java @@ -81,14 +81,6 @@ public List args() { return new CUser(name, args, cause, message, criticalEdges); } - @Override public R match(Cases cases) { - return cases.caseUser(this); - } - - @Override public R matchOrThrow(CheckedCases cases) throws E { - return cases.caseUser(this); - } - @Override public Set.Immutable getVars() { final Set.Transient vars = Set.Transient.of(); for(ITerm a : args) { @@ -138,6 +130,10 @@ private void doVisitFreeVars(Action1 onFreeVar) { return sb.toString(); } + @Override public Tag constraintTag() { + return Tag.CUser; + } + @Override public String toString() { return toString(ITerm::toString); } diff --git a/statix.solver/src/main/java/mb/statix/constraints/Constraints.java b/statix.solver/src/main/java/mb/statix/constraints/Constraints.java index 211b9a312..0ff543c1d 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/Constraints.java +++ b/statix.solver/src/main/java/mb/statix/constraints/Constraints.java @@ -30,344 +30,45 @@ public final class Constraints { private Constraints() { } - // @formatter:off - public static IConstraint.Cases cases( - Function1 onArith, - Function1 onConj, - Function1 onEqual, - Function1 onExists, - Function1 onFalse, - Function1 onInequal, - Function1 onNew, - Function1 onResolveQuery, - Function1 onTellEdge, - Function1 onTermId, - Function1 onTermProperty, - Function1 onTrue, - Function1 onTry, - Function1 onUser - ) { - return new IConstraint.Cases() { - - @Override public R caseArith(CArith c) { - return onArith.apply(c); - } - - @Override public R caseConj(CConj c) { - return onConj.apply(c); - } - - @Override public R caseEqual(CEqual c) { - return onEqual.apply(c); - } - - @Override public R caseExists(CExists c) { - return onExists.apply(c); - } - - @Override public R caseFalse(CFalse c) { - return onFalse.apply(c); - } - - @Override public R caseInequal(CInequal c) { - return onInequal.apply(c); - } - - @Override public R caseNew(CNew c) { - return onNew.apply(c); - } - - @Override public R caseResolveQuery(IResolveQuery c) { - return onResolveQuery.apply(c); - } - - @Override public R caseTellEdge(CTellEdge c) { - return onTellEdge.apply(c); - } - - @Override public R caseTermId(CAstId c) { - return onTermId.apply(c); - } - - @Override public R caseTermProperty(CAstProperty c) { - return onTermProperty.apply(c); - } - - @Override public R caseTrue(CTrue c) { - return onTrue.apply(c); - } - - @Override public R caseTry(CTry c) { - return onTry.apply(c); - } - - @Override public R caseUser(CUser c) { - return onUser.apply(c); - } - - }; - } - // @formatter:on - - public static CaseBuilder cases() { - return new CaseBuilder<>(); - } - - public static class CaseBuilder { - - private Function1 onArith; - private Function1 onConj; - private Function1 onEqual; - private Function1 onExists; - private Function1 onFalse; - private Function1 onInequal; - private Function1 onNew; - private Function1 onResolveQuery; - private Function1 onTellEdge; - private Function1 onTermId; - private Function1 onTermProperty; - private Function1 onTrue; - private Function1 onTry; - private Function1 onUser; - - public CaseBuilder arith(Function1 onArith) { - this.onArith = onArith; - return this; - } - - public CaseBuilder conj(Function1 onConj) { - this.onConj = onConj; - return this; - } - - public CaseBuilder equal(Function1 onEqual) { - this.onEqual = onEqual; - return this; - } - - public CaseBuilder exists(Function1 onExists) { - this.onExists = onExists; - return this; - } - - public CaseBuilder _false(Function1 onFalse) { - this.onFalse = onFalse; - return this; - } - - public CaseBuilder inequal(Function1 onInequal) { - this.onInequal = onInequal; - return this; - } - - public CaseBuilder _new(Function1 onNew) { - this.onNew = onNew; - return this; - } - - public CaseBuilder resolveQuery(Function1 onResolveQuery) { - this.onResolveQuery = onResolveQuery; - return this; - } - - public CaseBuilder tellEdge(Function1 onTellEdge) { - this.onTellEdge = onTellEdge; - return this; - } - - public CaseBuilder termId(Function1 onTermId) { - this.onTermId = onTermId; - return this; - } - - public CaseBuilder termProperty(Function1 onTermProperty) { - this.onTermProperty = onTermProperty; - return this; - } - - public CaseBuilder _true(Function1 onTrue) { - this.onTrue = onTrue; - return this; - } - - public CaseBuilder _try(Function1 onTry) { - this.onTry = onTry; - return this; - } - - public CaseBuilder user(Function1 onUser) { - this.onUser = onUser; - return this; - } - - public IConstraint.Cases otherwise(Function1 otherwise) { - return new IConstraint.Cases() { - - @Override public R caseArith(CArith c) { - return onArith != null ? onArith.apply(c) : otherwise.apply(c); - } - - @Override public R caseConj(CConj c) { - return onConj != null ? onConj.apply(c) : otherwise.apply(c); - } - - @Override public R caseEqual(CEqual c) { - return onEqual != null ? onEqual.apply(c) : otherwise.apply(c); - } - - @Override public R caseExists(CExists c) { - return onExists != null ? onExists.apply(c) : otherwise.apply(c); - } - - @Override public R caseFalse(CFalse c) { - return onFalse != null ? onFalse.apply(c) : otherwise.apply(c); - } - - @Override public R caseInequal(CInequal c) { - return onInequal != null ? onInequal.apply(c) : otherwise.apply(c); - } - - @Override public R caseNew(CNew c) { - return onNew != null ? onNew.apply(c) : otherwise.apply(c); - } - - @Override public R caseResolveQuery(IResolveQuery c) { - return onResolveQuery != null ? onResolveQuery.apply(c) : otherwise.apply(c); - } - - @Override public R caseTellEdge(CTellEdge c) { - return onTellEdge != null ? onTellEdge.apply(c) : otherwise.apply(c); - } - - @Override public R caseTermId(CAstId c) { - return onTermId != null ? onTermId.apply(c) : otherwise.apply(c); - } - - @Override public R caseTermProperty(CAstProperty c) { - return onTermProperty != null ? onTermProperty.apply(c) : otherwise.apply(c); - } - - @Override public R caseTrue(CTrue c) { - return onTrue != null ? onTrue.apply(c) : otherwise.apply(c); - } - - @Override public R caseTry(CTry c) { - return onTry != null ? onTry.apply(c) : otherwise.apply(c); - } - - @Override public R caseUser(CUser c) { - return onUser != null ? onUser.apply(c) : otherwise.apply(c); - } - - }; - } - - } - - - // @formatter:off - public static IConstraint.CheckedCases checkedCases( - CheckedFunction1 onArith, - CheckedFunction1 onConj, - CheckedFunction1 onEqual, - CheckedFunction1 onExists, - CheckedFunction1 onFalse, - CheckedFunction1 onInequal, - CheckedFunction1 onNew, - CheckedFunction1 onResolveQuery, - CheckedFunction1 onTellEdge, - CheckedFunction1 onTermId, - CheckedFunction1 onTermProperty, - CheckedFunction1 onTrue, - CheckedFunction1 onTry, - CheckedFunction1 onUser - ) { - return new IConstraint.CheckedCases() { - - @Override public R caseArith(CArith c) throws E { - return onArith.apply(c); - } - - @Override public R caseConj(CConj c) throws E { - return onConj.apply(c); - } - - @Override public R caseEqual(CEqual c) throws E { - return onEqual.apply(c); - } - - @Override public R caseExists(CExists c) throws E { - return onExists.apply(c); - } - - @Override public R caseFalse(CFalse c) throws E { - return onFalse.apply(c); - } - - @Override public R caseInequal(CInequal c) throws E { - return onInequal.apply(c); - } - - @Override public R caseNew(CNew c) throws E { - return onNew.apply(c); - } - - @Override public R caseResolveQuery(IResolveQuery c) throws E { - return onResolveQuery.apply(c); - } - - @Override public R caseTellEdge(CTellEdge c) throws E { - return onTellEdge.apply(c); - } - - @Override public R caseTermId(CAstId c) throws E { - return onTermId.apply(c); - } - - @Override public R caseTermProperty(CAstProperty c) throws E { - return onTermProperty.apply(c); - } - - @Override public R caseTrue(CTrue c) throws E { - return onTrue.apply(c); - } - - @Override public R caseTry(CTry c) throws E { - return onTry.apply(c); - } - - @Override public R caseUser(CUser c) throws E { - return onUser.apply(c); - } - - }; - } - // @formatter:on - /** * Bottom up transformation, where the transformation is applied starting from the leaves, then to the transformed * parents until the root. */ public static Function1 bottomup(Function1 f, boolean recurseInLogicalScopes) { - // @formatter:off - return cases( - c -> f.apply(c), - c -> f.apply(new CConj(bottomup(f, recurseInLogicalScopes).apply(c.left()), bottomup(f, recurseInLogicalScopes).apply(c.right()), c.cause().orElse(null))), - c -> f.apply(c), - c -> f.apply(new CExists(c.vars(), bottomup(f, recurseInLogicalScopes).apply(c.constraint()), c.cause().orElse(null))), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(recurseInLogicalScopes ? new CTry(bottomup(f, recurseInLogicalScopes).apply(c.constraint()), c.cause().orElse(null), c.message().orElse(null)) : c), - c -> f.apply(c) - ); - // @formatter:on + return constraint -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj c = (CConj) constraint; + return f.apply(new CConj(bottomup(f, recurseInLogicalScopes).apply(c.left()), + bottomup(f, recurseInLogicalScopes).apply(c.right()), c.cause().orElse(null))); + } + case CExists: { + CExists c = (CExists) constraint; + return f.apply(new CExists(c.vars(), bottomup(f, recurseInLogicalScopes).apply(c.constraint()), + c.cause().orElse(null))); + } + case CTry: { + CTry c = (CTry) constraint; + return f.apply(recurseInLogicalScopes ? new CTry(bottomup(f, recurseInLogicalScopes).apply(c.constraint()), + c.cause().orElse(null), c.message().orElse(null)) : c); + } + case CArith: + case CEqual: + case CFalse: + case CInequal: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTrue: + case CUser: + return f.apply(constraint); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); + }; } /** @@ -375,38 +76,44 @@ public static Function1 bottomup(Function1 map(Function1 f, boolean recurseInLogicalScopes) { - // @formatter:off - return cases( - c -> f.apply(c), - c -> { - final IConstraint left = map(f, recurseInLogicalScopes).apply(c.left()); - final IConstraint right = map(f, recurseInLogicalScopes).apply(c.right()); - return new CConj(left, right, c.cause().orElse(null)); - }, - c -> f.apply(c), - c -> { - final IConstraint body = map(f, recurseInLogicalScopes).apply(c.constraint()); - return new CExists(c.vars(), body, c.cause().orElse(null)); - }, - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> { - if(recurseInLogicalScopes) { + return constraint -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj c = (CConj) constraint; + final IConstraint left = map(f, recurseInLogicalScopes).apply(c.left()); + final IConstraint right = map(f, recurseInLogicalScopes).apply(c.right()); + return new CConj(left, right, c.cause().orElse(null)); + } + case CExists: { + CExists c = (CExists) constraint; final IConstraint body = map(f, recurseInLogicalScopes).apply(c.constraint()); - return new CTry(body, c.cause().orElse(null), c.message().orElse(null)); - } else { - return c; + return new CExists(c.vars(), body, c.cause().orElse(null)); + } + case CTry: { + CTry c = (CTry) constraint; + if(recurseInLogicalScopes) { + final IConstraint body = map(f, recurseInLogicalScopes).apply(c.constraint()); + return new CTry(body, c.cause().orElse(null), c.message().orElse(null)); + } else { + return c; + } } - }, - c -> f.apply(c) - ); - // @formatter:on + case CArith: + case CEqual: + case CFalse: + case CInequal: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTrue: + case CUser: + return f.apply(constraint); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); + }; } /** @@ -414,38 +121,44 @@ public static Function1 map(Function1> filter(Function1> f, boolean recurseInLogicalScopes) { - // @formatter:off - return cases( - c -> f.apply(c), - c -> { - final Optional left = filter(f, recurseInLogicalScopes).apply(c.left()); - final Optional right = filter(f, recurseInLogicalScopes).apply(c.right()); - return Optionals.lift(left, right, (l, r) -> new CConj(l, r, c.cause().orElse(null))); - }, - c -> f.apply(c), - c -> { - final Optional body = filter(f, recurseInLogicalScopes).apply(c.constraint()); - return body.map(b -> new CExists(c.vars(), b, c.cause().orElse(null))); - }, - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> { - if(recurseInLogicalScopes) { + return constraint -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj c = (CConj) constraint; + final Optional left = filter(f, recurseInLogicalScopes).apply(c.left()); + final Optional right = filter(f, recurseInLogicalScopes).apply(c.right()); + return Optionals.lift(left, right, (l, r) -> new CConj(l, r, c.cause().orElse(null))); + } + case CExists: { + CExists c = (CExists) constraint; final Optional body = filter(f, recurseInLogicalScopes).apply(c.constraint()); - return body.map(b -> new CTry(b, c.cause().orElse(null), c.message().orElse(null))); - } else { - return Optional.of(c); + return body.map(b -> new CExists(c.vars(), b, c.cause().orElse(null))); + } + case CTry: { + CTry c = (CTry) constraint; + if(recurseInLogicalScopes) { + final Optional body = filter(f, recurseInLogicalScopes).apply(c.constraint()); + return body.map(b -> new CTry(b, c.cause().orElse(null), c.message().orElse(null))); + } else { + return Optional.of(c); + } } - }, - c -> f.apply(c) - ); - // @formatter:on + case CArith: + case CEqual: + case CFalse: + case CInequal: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTrue: + case CUser: + return f.apply(constraint); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); + }; } /** @@ -453,42 +166,48 @@ public static Function1> filter(Function1> flatMap(Function1> f, boolean recurseInLogicalScopes) { - // @formatter:off - return cases( - c -> f.apply(c), - c -> { - return flatMap(f, recurseInLogicalScopes).apply(c.left()).flatMap(l -> { - return flatMap(f, recurseInLogicalScopes).apply(c.right()).map(r -> { - return new CConj(l, r, c.cause().orElse(null)); + return constraint -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj c = (CConj) constraint; + return flatMap(f, recurseInLogicalScopes).apply(c.left()).flatMap(l -> { + return flatMap(f, recurseInLogicalScopes).apply(c.right()).map(r -> { + return new CConj(l, r, c.cause().orElse(null)); + }); }); - }); - }, - c -> f.apply(c), - c -> { - return flatMap(f, recurseInLogicalScopes).apply(c.constraint()).map(b -> { - return new CExists(c.vars(), b, c.cause().orElse(null)); - }); - }, - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> f.apply(c), - c -> { - if(recurseInLogicalScopes) { + } + case CExists: { + CExists c = (CExists) constraint; return flatMap(f, recurseInLogicalScopes).apply(c.constraint()).map(b -> { - return new CTry(b, c.cause().orElse(null), c.message().orElse(null)); + return new CExists(c.vars(), b, c.cause().orElse(null)); }); - } else { - return Stream.of(c); } - }, - c -> f.apply(c) - ); - // @formatter:on + case CTry: { + CTry c = (CTry) constraint; + if(recurseInLogicalScopes) { + return flatMap(f, recurseInLogicalScopes).apply(c.constraint()).map(b -> { + return new CTry(b, c.cause().orElse(null), c.message().orElse(null)); + }); + } else { + return Stream.of(c); + } + } + case CArith: + case CEqual: + case CFalse: + case CInequal: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTrue: + case CUser: + return f.apply(constraint); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); + }; } public static Function1> collectBase(PartialFunction1 f, @@ -501,25 +220,42 @@ public static Function1> collectBase(PartialFunction1 void collectBase(IConstraint constraint, PartialFunction1 f, - ImmutableList.Builder ts, boolean recurseInLogicalScopes) { - // @formatter:off - constraint.match(cases( - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { disjoin(c).forEach(cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { disjoin(c.constraint()).forEach(cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; }, - c -> { if(recurseInLogicalScopes) { disjoin(c.constraint()).forEach(cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); } return Unit.unit; }, - c -> { f.apply(c).ifPresent(ts::add); return Unit.unit; } - )); - // @formatter:on + ImmutableList.Builder ts, boolean recurseInLogicalScopes) { + switch(constraint.constraintTag()) { + case CConj: { + CConj c = (CConj) constraint; + disjoin(c).forEach(cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); + break; + } + case CExists: { + CExists c = (CExists) constraint; + disjoin(c.constraint()).forEach( + cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); + break; + } + case CTry: { + CTry c = (CTry) constraint; + if(recurseInLogicalScopes) { + disjoin(c.constraint()).forEach( + cc -> collectBase(cc, f, ts, recurseInLogicalScopes)); + } + break; + } + case CArith: + case CEqual: + case CFalse: + case CInequal: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTrue: + case CUser: { + f.apply(constraint).ifPresent(ts::add); + break; + } + } } public static List apply(List constraints, ISubstitution.Immutable subst) { @@ -583,18 +319,18 @@ public static void disjoin(IConstraint constraint, Action1 action) Deque worklist = Lists.newLinkedList(); worklist.push(constraint); while(!worklist.isEmpty()) { - worklist.pop().match(Constraints.cases().conj(conj -> { + IConstraint c = worklist.pop(); + if(c.constraintTag() == IConstraint.Tag.CConj) { + CConj conj = (CConj) c; // HEURISTIC Use the cause of the surrounding conjunction, or keep // the cause of the constraints. This is a heuristic which seems // to work well, but in general maintaining causes requires some // care throughout the solver code. worklist.push(conj.left().withCause(conj.cause().orElse(conj.left().cause().orElse(null)))); worklist.push(conj.right().withCause(conj.cause().orElse(conj.right().cause().orElse(null)))); - return Unit.unit; - }).otherwise(c -> { + } else { action.apply(c); - return Unit.unit; - })); + } } } @@ -615,74 +351,82 @@ public static Set.Immutable vars(IConstraint constraint) { } public static void vars(IConstraint constraint, Action1 onVar) { - // @formatter:off - constraint.match(Constraints.cases( - onArith -> { + switch(constraint.constraintTag()) { + case CArith: { + CArith onArith = (CArith) constraint; onArith.expr1().isTerm().ifPresent(t -> t.getVars().forEach(onVar::apply)); onArith.expr2().isTerm().ifPresent(t -> t.getVars().forEach(onVar::apply)); - return Unit.unit; - }, - onConj -> { + break; + } + case CConj: { + CConj onConj = (CConj) constraint; Constraints.disjoin(onConj).forEach(c -> vars(c, onVar)); - return Unit.unit; - }, - onEqual -> { + break; + } + case CEqual: { + CEqual onEqual = (CEqual) constraint; onEqual.term1().getVars().forEach(onVar::apply); onEqual.term2().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onExists -> { + break; + } + case CExists: { + CExists onExists = (CExists) constraint; onExists.vars().forEach(onVar::apply); vars(onExists.constraint(), onVar); - return Unit.unit; - }, - onFalse -> { - return Unit.unit; - }, - onInequal -> { + break; + } + case CInequal: { + CInequal onInequal = (CInequal) constraint; onInequal.term1().getVars().stream().filter(v -> !onInequal.universals().contains(v)).forEach(onVar::apply); onInequal.term2().getVars().stream().filter(v -> !onInequal.universals().contains(v)).forEach(onVar::apply); - return Unit.unit; - }, - onNew -> { + break; + } + case CNew: { + CNew onNew = (CNew) constraint; onNew.scopeTerm().getVars().forEach(onVar::apply); onNew.datumTerm().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onResolveQuery -> { + break; + } + case IResolveQuery: { + IResolveQuery onResolveQuery = (IResolveQuery) constraint; onResolveQuery.scopeTerm().getVars().forEach(onVar::apply); RuleUtil.vars(onResolveQuery.filter().getDataWF(), onVar); RuleUtil.vars(onResolveQuery.min().getDataEquiv(), onVar); onResolveQuery.resultTerm().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onTellEdge -> { + break; + } + case CTellEdge: { + CTellEdge onTellEdge = (CTellEdge) constraint; onTellEdge.sourceTerm().getVars().forEach(onVar::apply); onTellEdge.targetTerm().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onTermId -> { + break; + } + case CAstId: { + CAstId onTermId = (CAstId) constraint; onTermId.astTerm().getVars().forEach(onVar::apply); onTermId.idTerm().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onTermProperty -> { + break; + } + case CAstProperty: { + CAstProperty onTermProperty = (CAstProperty) constraint; onTermProperty.idTerm().getVars().forEach(onVar::apply); onTermProperty.value().getVars().forEach(onVar::apply); - return Unit.unit; - }, - onTrue -> null, - onTry -> { + break; + } + case CTry: { + CTry onTry = (CTry) constraint; vars(onTry.constraint(), onVar); - return Unit.unit; - }, - onUser -> { + break; + } + case CUser: { + CUser onUser = (CUser) constraint; onUser.args().forEach(t -> t.getVars().forEach(onVar::apply)); - return Unit.unit; + break; } - )); - // @formatter:on - + case CFalse: + case CTrue: + break; + } } public static IConstraint exists(Iterable vars, IConstraint body) { @@ -704,9 +448,14 @@ public static IConstraint exists(Iterable vars, IConstraint body) { * otherwise, none */ public static Optional trivial(IConstraint constraint) { - return Optional.ofNullable( - constraint.match(Constraints.cases(c -> null, c -> null, c -> null, c -> null, c -> false, c -> null, - c -> null, c -> null, c -> null, c -> null, c -> null, c -> true, c -> null, c -> null))); + switch(constraint.constraintTag()) { + case CTrue: + return Optional.of(true); + case CFalse: + return Optional.of(false); + default: + return Optional.empty(); + } } } diff --git a/statix.solver/src/main/java/mb/statix/constraints/IResolveQuery.java b/statix.solver/src/main/java/mb/statix/constraints/IResolveQuery.java index b4a6327d6..3d7005156 100644 --- a/statix.solver/src/main/java/mb/statix/constraints/IResolveQuery.java +++ b/statix.solver/src/main/java/mb/statix/constraints/IResolveQuery.java @@ -18,33 +18,21 @@ public interface IResolveQuery extends IConstraint { ITerm resultTerm(); - R match(Cases cases); - R matchInResolution(ResolutionFunction1 onResolveQuery, ResolutionFunction1 onCompiledQuery) throws ResolutionException, InterruptedException; - @Override default R match(IConstraint.Cases cases) { - return cases.caseResolveQuery(this); + interface ResolutionFunction1 { + R apply(T t) throws ResolutionException, InterruptedException; } - @Override default R matchOrThrow(IConstraint.CheckedCases cases) throws E { - return cases.caseResolveQuery(this); + @Override default IConstraint.Tag constraintTag() { + return IConstraint.Tag.IResolveQuery; } - interface Cases extends Function1 { - - R caseResolveQuery(CResolveQuery q); - - R caseCompiledQuery(CCompiledQuery q); - - @Override default R apply(IResolveQuery q) { - return q.match(this); - } - - } - - interface ResolutionFunction1 { - R apply(T t) throws ResolutionException, InterruptedException; + enum Tag { + CResolveQuery, + CCompiledQuery } + Tag resolveQueryTag(); } diff --git a/statix.solver/src/main/java/mb/statix/solver/IConstraint.java b/statix.solver/src/main/java/mb/statix/solver/IConstraint.java index 490d4b25c..ebc2b52ab 100644 --- a/statix.solver/src/main/java/mb/statix/solver/IConstraint.java +++ b/statix.solver/src/main/java/mb/statix/solver/IConstraint.java @@ -65,10 +65,6 @@ default IConstraint withBodyCriticalEdges(@SuppressWarnings("unused") ICompleten throw new UnsupportedOperationException("Constraint does not support body critical edges."); } - R match(Cases cases); - - R matchOrThrow(CheckedCases cases) throws E; - Set.Immutable getVars(); Set.Immutable freeVars(); @@ -92,76 +88,23 @@ default IConstraint withBodyCriticalEdges(@SuppressWarnings("unused") ICompleten String toString(TermFormatter termToString); - interface Cases extends Function1 { - - R caseArith(CArith c); - - R caseConj(CConj c); - - R caseEqual(CEqual c); - - R caseExists(CExists c); - - R caseFalse(CFalse c); - - R caseInequal(CInequal c); - - R caseNew(CNew c); - - R caseResolveQuery(IResolveQuery c); - - R caseTellEdge(CTellEdge c); - - R caseTermId(CAstId c); - - R caseTermProperty(CAstProperty c); - - R caseTrue(CTrue c); - - R caseTry(CTry c); - - R caseUser(CUser c); - - @Override default R apply(IConstraint c) { - return c.match(this); - } - - } - - interface CheckedCases extends CheckedFunction1 { - - R caseArith(CArith c) throws E; - - R caseConj(CConj c) throws E; - - R caseEqual(CEqual c) throws E; - - R caseExists(CExists c) throws E; - - R caseFalse(CFalse c) throws E; - - R caseInequal(CInequal c) throws E; - - R caseNew(CNew c) throws E; - - R caseResolveQuery(IResolveQuery c) throws E; - - R caseTellEdge(CTellEdge c) throws E; - - R caseTermId(CAstId c) throws E; - - R caseTermProperty(CAstProperty c) throws E; - - R caseTrue(CTrue c) throws E; - - R caseTry(CTry c) throws E; - - R caseUser(CUser c) throws E; - - @Override default R apply(IConstraint c) throws E { - return c.matchOrThrow(this); - } - + Tag constraintTag(); + + enum Tag { + CArith, + CConj, + CEqual, + CExists, + CFalse, + CInequal, + CNew, + IResolveQuery, + CTellEdge, + CAstId, + CAstProperty, + CTrue, + CTry, + CUser } } diff --git a/statix.solver/src/main/java/mb/statix/solver/completeness/CompletenessUtil.java b/statix.solver/src/main/java/mb/statix/solver/completeness/CompletenessUtil.java index e5604c90c..9ef8435a6 100644 --- a/statix.solver/src/main/java/mb/statix/solver/completeness/CompletenessUtil.java +++ b/statix.solver/src/main/java/mb/statix/solver/completeness/CompletenessUtil.java @@ -20,6 +20,7 @@ import mb.scopegraph.oopsla20.reference.EdgeOrData; import mb.statix.constraints.CCompiledQuery; import mb.statix.constraints.CConj; +import mb.statix.constraints.CExists; import mb.statix.constraints.CNew; import mb.statix.constraints.CResolveQuery; import mb.statix.constraints.CTellEdge; @@ -41,44 +42,48 @@ public class CompletenessUtil { * Discover critical edges in constraint. The scopeTerm is not guaranteed to be ground or instantiated. */ static void criticalEdges(IConstraint constraint, Spec spec, Action2> criticalEdge) { - // @formatter:off - constraint.match(Constraints.cases( - onArith -> Unit.unit, - onConj -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj onConj = (CConj) constraint; Constraints.disjoin(onConj).forEach(c -> criticalEdges(c, spec, criticalEdge)); - return Unit.unit; - }, - onEqual -> Unit.unit, - onExists -> { + break; + } + case CExists: { + CExists onExists = (CExists) constraint; criticalEdges(onExists.constraint(), spec, (s, l) -> { if(!onExists.vars().contains(s)) { criticalEdge.apply(s, l); } }); - return Unit.unit; - }, - onFalse -> Unit.unit, - onInequal -> Unit.unit, - onNew -> { + break; + } + case CNew: { + CNew onNew = (CNew) constraint; criticalEdge.apply(onNew.scopeTerm(), EdgeOrData.data()); - return Unit.unit; - }, - onResolveQuery -> Unit.unit, - onTellEdge -> { + break; + } + case CTellEdge: { + CTellEdge onTellEdge = (CTellEdge) constraint; criticalEdge.apply(onTellEdge.sourceTerm(), EdgeOrData.edge(onTellEdge.label())); - return Unit.unit; - }, - onTermId -> Unit.unit, - onTermProperty -> Unit.unit, - onTrue -> Unit.unit, - onTry -> Unit.unit, - onUser -> { + break; + } + case CUser: { + CUser onUser = (CUser) constraint; spec.scopeExtensions().get(onUser.name()).stream() - .forEach(il -> criticalEdge.apply(onUser.args().get(il._1()), EdgeOrData.edge(il._2()))); - return Unit.unit; + .forEach(il -> criticalEdge.apply(onUser.args().get(il._1()), EdgeOrData.edge(il._2()))); + break; } - )); - // @formatter:on + case CArith: + case CEqual: + case CFalse: + case CInequal: + case IResolveQuery: + case CAstId: + case CAstProperty: + case CTrue: + case CTry: + break; + } } /** @@ -144,16 +149,15 @@ static Rule precomputeCriticalEdges(Rule rule, SetMultimap> spec, Action2> criticalEdge) { - // @formatter:off - return constraint.match(Constraints.cases( - carith -> carith, - cconj -> { + switch(constraint.constraintTag()) { + case CConj: { + CConj cconj = (CConj) constraint; final IConstraint newLeft = precomputeCriticalEdges(cconj.left(), spec, criticalEdge); final IConstraint newRight = precomputeCriticalEdges(cconj.right(), spec, criticalEdge); return new CConj(newLeft, newRight, cconj.cause().orElse(null)); - }, - cequal -> cequal, - cexists -> { + } + case CExists: { + CExists cexists = (CExists) constraint; final ICompleteness.Transient bodyCriticalEdges = Completeness.Transient.of(); final IConstraint newBody = precomputeCriticalEdges(cexists.constraint(), spec, (s, l) -> { if(cexists.vars().contains(s)) { @@ -163,10 +167,9 @@ static IConstraint precomputeCriticalEdges(IConstraint constraint, SetMultimap cfalse, - cinequal -> cinequal, - cnew -> { + } + case CNew: { + CNew cnew = (CNew) constraint; final ICompleteness.Transient ownCriticalEdges = Completeness.Transient.of(); final ITerm scopeOrVar; if((scopeOrVar = scopeOrVar().match(cnew.scopeTerm()).orElse(null)) != null) { @@ -174,25 +177,27 @@ static IConstraint precomputeCriticalEdges(IConstraint constraint, SetMultimap { + } + case IResolveQuery: { + IResolveQuery iresolveQuery = (IResolveQuery) constraint; final QueryFilter newFilter = - new QueryFilter(iresolveQuery.filter().getLabelWF(), precomputeCriticalEdges(iresolveQuery.filter().getDataWF(), spec)); + new QueryFilter(iresolveQuery.filter().getLabelWF(), precomputeCriticalEdges(iresolveQuery.filter().getDataWF(), spec)); final QueryMin newMin = - new QueryMin(iresolveQuery.min().getLabelOrder(), precomputeCriticalEdges(iresolveQuery.min().getDataEquiv(), spec)); - return iresolveQuery.match(new IResolveQuery.Cases() { - - @Override public IResolveQuery caseResolveQuery(CResolveQuery q) { + new QueryMin(iresolveQuery.min().getLabelOrder(), precomputeCriticalEdges(iresolveQuery.min().getDataEquiv(), spec)); + switch(iresolveQuery.resolveQueryTag()) { + case CResolveQuery: { CResolveQuery q = (CResolveQuery) iresolveQuery; return new CResolveQuery(newFilter, newMin, q.scopeTerm(), q.resultTerm(), - q.cause().orElse(null), q.message().orElse(null)); + q.cause().orElse(null), q.message().orElse(null)); } - @Override public IResolveQuery caseCompiledQuery(CCompiledQuery q) { + case CCompiledQuery: { CCompiledQuery q = (CCompiledQuery) iresolveQuery; return new CCompiledQuery(newFilter, newMin, q.scopeTerm(), q.resultTerm(), - q.cause().orElse(null), q.message().orElse(null), q.stateMachine()); - }}); - }, - ctellEdge -> { + q.cause().orElse(null), q.message().orElse(null), q.stateMachine()); + } + } + } + case CTellEdge: { + CTellEdge ctellEdge = (CTellEdge) constraint; final ICompleteness.Transient ownCriticalEdges = Completeness.Transient.of(); final ITerm scopeOrVar; if((scopeOrVar = scopeOrVar().match(ctellEdge.sourceTerm()).orElse(null)) != null) { @@ -200,16 +205,15 @@ static IConstraint precomputeCriticalEdges(IConstraint constraint, SetMultimap ctermId, - ctermProperty -> ctermProperty, - ctrue -> ctrue, - ctry -> { + ctellEdge.cause().orElse(null), ownCriticalEdges.freeze()); + } + case CTry: { + CTry ctry = (CTry) constraint; final IConstraint newBody = precomputeCriticalEdges(ctry.constraint(), spec, criticalEdge); return new CTry(newBody, ctry.cause().orElse(null), ctry.message().orElse(null)); - }, - cuser -> { + } + case CUser: { + CUser cuser = (CUser) constraint; final ICompleteness.Transient ownCriticalEdges = Completeness.Transient.of(); spec.get(cuser.name()).stream().forEach(il -> { final ITerm scopeOrVar; @@ -221,8 +225,18 @@ static IConstraint precomputeCriticalEdges(IConstraint constraint, SetMultimap() { - - @Override public Boolean caseArith(CArith c) throws InterruptedException { + switch(constraint.constraintTag()) { + case CArith: { CArith c = (CArith) constraint; final IUniDisunifier unifier = state.unifier(); final Optional term1 = c.expr1().isTerm(); final Optional term2 = c.expr2().isTerm(); @@ -353,11 +352,11 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseConj(CConj c) throws InterruptedException { + case CConj: { CConj c = (CConj) constraint; return success(c, state, NO_UPDATED_VARS, disjoin(c), NO_NEW_CRITICAL_EDGES, NO_EXISTENTIALS, fuel); } - @Override public Boolean caseEqual(CEqual c) throws InterruptedException { + case CEqual: { CEqual c = (CEqual) constraint; final ITerm term1 = c.term1(); final ITerm term2 = c.term2(); IDebugContext debug = params.debug(); @@ -389,7 +388,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseExists(CExists c) throws InterruptedException { + case CExists: { CExists c = (CExists) constraint; final Renaming.Builder _existentials = Renaming.builder(); IState.Immutable newState = state; for(ITermVar var : c.vars()) { @@ -412,11 +411,11 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException existentials.asMap(), fuel); } - @Override public Boolean caseFalse(CFalse c) { + case CFalse: { CFalse c = (CFalse) constraint; return fail(c); } - @Override public Boolean caseInequal(CInequal c) throws InterruptedException { + case CInequal: { CInequal c = (CInequal) constraint; final ITerm term1 = c.term1(); final ITerm term2 = c.term2(); IDebugContext debug = params.debug(); @@ -442,7 +441,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseNew(CNew c) throws InterruptedException { + case CNew: { CNew c = (CNew) constraint; IState.Immutable newState = state; final ITerm scopeTerm = c.scopeTerm(); @@ -462,7 +461,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException NO_EXISTENTIALS, fuel); } - @Override public Boolean caseResolveQuery(IResolveQuery c) throws InterruptedException { + case IResolveQuery: { IResolveQuery c = (IResolveQuery) constraint; final QueryFilter filter = c.filter(); final QueryMin min = c.min(); final ITerm scopeTerm = c.scopeTerm(); @@ -534,7 +533,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseTellEdge(CTellEdge c) throws InterruptedException { + case CTellEdge: { CTellEdge c = (CTellEdge) constraint; final ITerm sourceTerm = c.sourceTerm(); final ITerm label = c.label(); final ITerm targetTerm = c.targetTerm(); @@ -565,7 +564,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException NO_NEW_CRITICAL_EDGES, NO_EXISTENTIALS, fuel); } - @Override public Boolean caseTermId(CAstId c) throws InterruptedException { + case CAstId: { CAstId c = (CAstId) constraint; final ITerm term = c.astTerm(); final ITerm idTerm = c.idTerm(); @@ -593,7 +592,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseTermProperty(CAstProperty c) throws InterruptedException { + case CAstProperty: { CAstProperty c = (CAstProperty) constraint; final ITerm idTerm = c.idTerm(); final ITerm prop = c.property(); final ITerm value = c.value(); @@ -642,12 +641,12 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseTrue(CTrue c) throws InterruptedException { + case CTrue: { CTrue c = (CTrue) constraint; return success(c, state, NO_UPDATED_VARS, NO_NEW_CONSTRAINTS, NO_NEW_CRITICAL_EDGES, NO_EXISTENTIALS, fuel); } - @Override public Boolean caseTry(CTry c) throws InterruptedException { + case CTry: { CTry c = (CTry) constraint; final IDebugContext debug = params.debug(); try { if(Solver.entails(spec, state, c.constraint(), params::isComplete, new NullDebugContext(), @@ -663,7 +662,7 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException } } - @Override public Boolean caseUser(CUser c) throws InterruptedException { + case CUser: { CUser c = (CUser) constraint; final String name = c.name(); final List args = c.args(); @@ -694,7 +693,9 @@ private boolean k(IConstraint constraint, int fuel) throws InterruptedException applyResult.criticalEdges(), NO_EXISTENTIALS, fuel); } - }); + } + // N.B. don't use this in default case branch, instead use IDE to catch non-exhaustive switch statements + throw new RuntimeException("Missing case for IConstraint subclass/tag"); } diff --git a/statix.solver/src/main/java/mb/statix/spec/PreSolvedConstraint.java b/statix.solver/src/main/java/mb/statix/spec/PreSolvedConstraint.java index ac9d21f0d..8e2aca37c 100644 --- a/statix.solver/src/main/java/mb/statix/spec/PreSolvedConstraint.java +++ b/statix.solver/src/main/java/mb/statix/spec/PreSolvedConstraint.java @@ -36,8 +36,11 @@ import mb.nabl2.terms.unification.ud.IUniDisunifier; import mb.nabl2.terms.unification.ud.PersistentUniDisunifier; import mb.nabl2.util.TermFormatter; +import mb.statix.constraints.CConj; +import mb.statix.constraints.CEqual; import mb.statix.constraints.CExists; import mb.statix.constraints.CFalse; +import mb.statix.constraints.CInequal; import mb.statix.constraints.Constraints; import mb.statix.solver.Delay; import mb.statix.solver.IConstraint; @@ -432,29 +435,37 @@ public static void preSolve(IConstraint constraint, Function1cases( - carith -> { constraints.add(c.withCause(cause)); return true; }, - conj -> { worklist.addAll(Constraints.disjoin(conj)); return true; }, - cequal -> { + boolean okay = true; + switch(c.constraintTag()) { + case CConj: { + CConj conj = (CConj) c; + worklist.addAll(Constraints.disjoin(conj)); + break; + } + case CEqual: { + CEqual cequal = (CEqual) c; try { final IUnifier.Immutable result; if((result = unifier.unify(cequal.term1(), cequal.term2(), isRigid).orElse(null)) == null) { failures.add(cequal.withCause(cause)); - return false; + okay = false; + break; } updatedVars.addAll(result.domainSet()); bodyCriticalEdges.updateAll(result.domainSet(), result); - return true; + break; } catch(OccursException e) { failures.add(cequal.withCause(cause)); - return false; + okay = false; + break; } catch(RigidException e) { delays.put(cequal, Delay.ofVars(e.vars())); - return false; + okay = false; + break; } - }, - cexists -> { + } + case CExists: { + CExists cexists = (CExists) c; final IRenaming renaming = fresh.apply(cexists.vars()/*FIXME possible opt: .__retainAll(cexists.constraint().freeVars())*/); if(first.get()) { existentials.putAll(renaming.asMap()); @@ -463,35 +474,45 @@ public static void preSolve(IConstraint constraint, Function1 { bodyCriticalEdges.addAll(bce.apply(renaming), unifier); }); - return true; - }, - cfalse -> { + break; + } + case CFalse: { + CFalse cfalse = (CFalse) c; failures.add(cfalse.withCause(cause)); - return false; - }, - cinequal -> { + okay = false; + break; + } + case CInequal: { + CInequal cinequal = (CInequal) c; try { if(!unifier.disunify(cinequal.universals(), cinequal.term1(), cinequal.term2(), isRigid).isPresent()) { failures.add(cinequal.withCause(cause)); - return false; + okay = false; + break; } - } catch (RigidException e) { + break; + } catch(RigidException e) { delays.put(cinequal, Delay.ofVars(e.vars())); - return false; + okay = false; + break; } - return true; - }, - cnew -> { constraints.add(c.withCause(cause)); return true; }, - cquery -> { constraints.add(c.withCause(cause)); return true; }, - ctelledge -> { constraints.add(c.withCause(cause)); return true; }, - castid -> { constraints.add(c.withCause(cause)); return true; }, - castprop -> { constraints.add(c.withCause(cause)); return true; }, - ctrue -> { return true; }, - ctry -> { constraints.add(c.withCause(cause)); return true; }, - cuser -> { constraints.add(c.withCause(cause)); return true; } - )); + } + case CTrue: { + break; + } + case CArith: + case CNew: + case IResolveQuery: + case CTellEdge: + case CAstId: + case CAstProperty: + case CTry: + case CUser: { + constraints.add(c.withCause(cause)); + break; + } + } first.set(false); - // @formatter:on if(!okay && returnOnFirstErrorOrDelay) { return; } diff --git a/statix.solver/src/main/java/mb/statix/spoofax/StatixPrimitive.java b/statix.solver/src/main/java/mb/statix/spoofax/StatixPrimitive.java index a1347a819..74f2a8f72 100644 --- a/statix.solver/src/main/java/mb/statix/spoofax/StatixPrimitive.java +++ b/statix.solver/src/main/java/mb/statix/spoofax/StatixPrimitive.java @@ -43,6 +43,7 @@ import mb.nabl2.terms.substitution.PersistentSubstitution; import mb.nabl2.terms.unification.ud.IUniDisunifier; import mb.nabl2.util.TermFormatter; +import mb.statix.constraints.CUser; import mb.statix.constraints.Constraints; import mb.statix.constraints.messages.IMessage; import mb.statix.solver.IConstraint; @@ -236,27 +237,13 @@ public static Tuple2, ITerm> formatMessage(final IMessage messa } private static Optional findOriginArgument(IConstraint constraint, IUniDisunifier unifier) { - // @formatter:off - final Function1> terms = Constraints.cases( - onArith -> Stream.empty(), - onConj -> Stream.empty(), - onEqual -> Stream.empty(), - onExists -> Stream.empty(), - onFalse -> Stream.empty(), - onInequal -> Stream.empty(), - onNew -> Stream.empty(), - onResolveQuery -> Stream.empty(), - onTellEdge -> Stream.empty(), - onTermId -> Stream.empty(), - onTermProperty -> Stream.empty(), - onTrue -> Stream.empty(), - onTry -> Stream.empty(), - onUser -> onUser.args().stream() - ); - return terms.apply(constraint) + if(constraint.constraintTag() == IConstraint.Tag.CUser) { + CUser onUser = (CUser) constraint; + return onUser.args().stream() .flatMap(t -> Streams.stream(getOriginTerm(t, unifier))) .findFirst(); - // @formatter:on + } + return Optional.empty(); } private static Optional getOriginTerm(ITerm term, IUniDisunifier unifier) {