import os
import glob
from pathlib import Path

import gradio as gr
import jpype
import jpype.imports
import pandas as pd

from graphviz import Digraph
from jpype import JClass, getDefaultJVMPath
from pslpython.partition import Partition


def _find_psl_jars() -> list[str]:
    """
    Priority:
      1) Any *.jar inside PSL_JARS_DIR (if set)
      2) Any *.jar in ./jars next to this script
      3) Installed pslpython runtime jar
    """
    jars: list[str] = []

    # 1) Explicit override via the PSL_JARS_DIR environment variable.
    env_dir = os.environ.get("PSL_JARS_DIR")
    if env_dir:
        jars.extend(glob.glob(f"{env_dir}/*.jar"))

    # 2) ./jars next to this script.
    this_dir = Path(__file__).resolve().parent
    jars.extend(glob.glob(f"{this_dir / 'jars'}/*.jar"))

    # 3) Jar(s) bundled with the installed pslpython package (best effort;
    #    the exact location varies across pslpython versions).
    import pslpython
    pkg_dir = Path(pslpython.__file__).resolve().parent
    jars.extend(glob.glob(f"{pkg_dir}/**/*.jar", recursive=True))

    # De-duplicate while preserving priority order; keep only files that exist.
    dedup = []
    seen = set()
    for j in jars:
        if j not in seen and Path(j).is_file():
            seen.add(j)
            dedup.append(j)

    return dedup


def start_psl_jvm(verbose: bool = True) -> list[str]:
    """
    Start a JVM with a classpath that includes PSL jars.
    Returns the list of jars used.
    """
    if jpype.isJVMStarted():
        if verbose:
            print("[PSL] JVM already started.")
        return []

    jars = _find_psl_jars()
    if not jars:
        raise RuntimeError(
            "No PSL jars found. Set PSL_JARS_DIR or place jars under ./jars next to this script."
        )

    classpath = os.pathsep.join(jars)

    jvm_path = getDefaultJVMPath()

    jpype.startJVM(
        jvm_path,
        f"-Djava.class.path={classpath}",
    )

    # Sanity check: resolve a core PSL class to confirm the classpath works.
    GA = JClass("org.linqs.psl.model.atom.GroundAtom")
    if verbose:
        print(f"[PSL] JVM started with {len(jars)} jars.")
        for j in jars:
            print(f"  - {j}")
        print(f"[PSL] Sanity check: loaded {GA}")

    return jars
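
# Note: JPype (>= 0.7) also accepts a `classpath=` keyword, so an equivalent
# start-up is `jpype.startJVM(classpath=jars)`; the explicit
# -Djava.class.path argument above also works on older JPype releases.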


import io

import pslpython
from pslpython.model import Model
from pslpython.predicate import Predicate
from pslpython.rule import Rule

MODEL_NAME = 'minimal-circuit'

ADDITIONAL_PSL_OPTIONS = {
    'runtime.log.level': 'INFO',
}


RULES = [
    {"name": "R1_perfdrop_implies_critical", "weight": 4.0,
     "body": ["InCircuit(E,C)", "PerfDrop(E,C)"], "head": "Critical(E,C)", "squared": True,
     "comment": "High performance drop on removal implies the edge is critical (necessary) for C."},

    {"name": "R2_safe_implies_removable", "weight": 1.0,
     "body": ["InCircuit(E,C)", "SafeToRemove(E,C)"], "head": "Removable(E,C)", "squared": True,
     "comment": "If an edge is in C and appears safe to remove, infer it’s removable."},

    {"name": "R3_removable_excludes_critical", "weight": 0.6,
     "body": ["Removable(E,C)"], "head": "!Critical(E,C)", "squared": True,
     "comment": "Edges inferred removable should not be marked critical (soft mutual exclusion)."},

    {"name": "R4_minimal_excludes_removable", "weight": 1.0,
     "body": ["Minimal(C)", "InCircuit(E,C)"], "head": "!Removable(E,C)", "squared": True,
     "comment": "A minimal circuit should not contain removable edges."},

    {"name": "R5_remmasshi_neg_minimal", "weight": 4.5,
     "body": ["RemMassHi(C)"], "head": "!Minimal(C)", "squared": True,
     "comment": "High total removable mass argues against minimality."},

    {"name": "R6_dominates_neg_minimal", "weight": 8.0,
     "body": ["Dominates(C2,C)"], "head": "!Minimal(C)", "squared": True,
     "comment": "If a cheaper sufficient rival dominates C, C is not minimal."},

    {"name": "R7_insufficient_neg_minimal", "weight": 8.0,
     "body": ["!Sufficient(C)"], "head": "!Minimal(C)", "squared": True,
     "comment": "Minimality applies only to sufficient circuits."},

    {"name": "R8_rand_circuit_better_neg_minimal", "weight": 4.0,
     "body": ["RandCircuitBetter(C,Cr)"], "head": "!Minimal(C)", "squared": True,
     "comment": "If random comparable circuits often outperform C, it is unlikely minimal."},

    {"name": "R9_tie_uncovered_neg_minimal", "weight": 1.2,
     "body": ["Tie(E1,E2,C)", "!TieCovered(E1,E2,C)"], "head": "!Minimal(C)", "squared": True,
     "comment": "Near-ties must be explored; uncovered ties weaken claims of minimality."},

    {"name": "R10_multiple_startseed_minimal", "weight": 0.5,
     "body": ["FromMultipleStartSeeds(C)"], "head": "Minimal(C)", "squared": True,
     "comment": "Consistent convergence to the same C from multiple starts nudges minimality upward."},

    {"name": "HC1_dominates_implies_not_minimal", "hard": True,
     "body": ["Dominates(C2,C)"], "head": "!Minimal(C)",
     "comment": "HARD: Whenever C2 dominates C, C cannot be minimal."},
]


def _literal_to_psl(lit: str) -> str:
    # Pass-through hook: literals are already written in PSL syntax.
    return lit


def _rule_to_psl(rule: dict) -> str:
    body = " & ".join(_literal_to_psl(l) for l in rule["body"])
    head = _literal_to_psl(rule["head"])
    if rule.get("hard", False):
        # Hard (unweighted) constraints end with a trailing period.
        return f"{body} -> {head} ."
    weight = rule["weight"]
    exp = " ^2" if rule.get("squared", False) else ""
    return f"{weight}: {body} -> {head}{exp}"


def _pred_name(lit: str) -> str:
    return lit.lstrip('!').split('(')[0].strip()


def _is_negated(lit: str) -> bool:
    return lit.strip().startswith('!')


def rules_to_graphviz_file(rules=RULES, basename: str = "rules_graph") -> str:
    """
    Render Graphviz to a PNG on disk and return the filepath.
    Produces 'basename.png' next to your app.
    """
    g = Digraph(name="CircuitMinimality", format="png", engine="dot")
    g.attr(rankdir="LR", fontname="Helvetica")
    g.node_attr.update(shape="box", style="rounded,filled", fillcolor="#f8f8f8", fontname="Helvetica")

    # One node per predicate appearing in any rule body or head.
    preds = set()
    for r in rules:
        preds.update(_pred_name(x) for x in (r["body"] + [r["head"]]))
    for p in sorted(preds):
        g.node(p)

    # One edge per (body literal -> head) pair:
    # green solid for positive heads, red dashed for negated heads.
    for r in rules:
        head_lit = r["head"]
        head_pred = _pred_name(head_lit)
        neg_head = _is_negated(head_lit)
        color = "#2ca02c" if not neg_head else "#d62728"
        style = "dashed" if neg_head else "solid"
        label = f"{r['name']} ({'HARD' if r.get('hard', False) else r.get('weight', '')})".strip()
        for b in r["body"]:
            g.edge(_pred_name(b), head_pred, color=color, fontcolor=color, style=style, penwidth="2", label=label)

    out = g.render(filename=basename, cleanup=True)
    png_path = out if out.endswith(".png") else f"{out}.png"
    return png_path


def _rules_commentary_md(rules=RULES) -> str:
    lines = ["### Rule Commentary", ""]
    for r in rules:
        badge = "**HARD**" if r.get("hard", False) else f"**w={r.get('weight', '')}**"
        psl = _rule_to_psl(r)
        lines.append(f"- **{r['name']}** ({badge})  \n  {r['comment']}  \n  `{psl}`")
    return "\n".join(lines)


# Start the JVM once at import time so the JClass lookups below succeed.
try:
    used_jars = start_psl_jvm(verbose=False)
    jvm_already = jpype.isJVMStarted() and (len(used_jars) == 0)
    JVM_STATUS = "✅ JVM already running." if jvm_already else f"✅ JVM started with {len(used_jars)} jar(s)."
    if used_jars:
        JVM_STATUS += "\n" + "\n".join([f"- {p}" for p in used_jars])

    # Sanity check that PSL classes are reachable on the classpath.
    _ = JClass("org.linqs.psl.model.atom.GroundAtom")
except Exception as e:
    JVM_STATUS = f"❌ JVM start failed: {e}"


COMMENTARY_MD = _rules_commentary_md(RULES)


def add_predicates(model):
    # Evidence about edges within circuits (largely observed).
    model.add_predicate(Predicate('InCircuit', size=2))
    model.add_predicate(Predicate('Sufficient', size=1))
    model.add_predicate(Predicate('PerfDrop', size=2))
    model.add_predicate(Predicate('SafeToRemove', size=2))
    model.add_predicate(Predicate('RemMassHi', size=1))

    # Predicates the rules infer (they appear in rule heads).
    model.add_predicate(Predicate('Removable', size=2))
    model.add_predicate(Predicate('Critical', size=2))
    model.add_predicate(Predicate('Minimal', size=1))

    # Comparisons against rival and random circuits.
    model.add_predicate(Predicate('Dominates', size=2))
    model.add_predicate(Predicate('RandCircuitBetter', size=2))

    # Near-tie tracking and multi-start provenance.
    model.add_predicate(Predicate('Tie', size=3))
    model.add_predicate(Predicate('TieCovered', size=3))
    model.add_predicate(Predicate('FromMultipleStartSeeds', size=1))


def add_rules(model, rules=RULES, attach_comments: bool = False):
    """Add rules to a pslpython model from the RULES list."""
    for r in rules:
        psl_text = _rule_to_psl(r)
        model.add_rule(Rule(psl_text))
        if attach_comments and r.get("comment"):
            print(f"# {r['name']}: {r['comment']}")


def infer(model):
    """Placeholder inference call (data can be attached before calling)."""
    return model.infer(psl_options=ADDITIONAL_PSL_OPTIONS)


def build_model(model_name=MODEL_NAME):
    """Build a full PSL model from scratch."""
    model = Model(model_name)
    add_predicates(model)
    add_rules(model)
    return model


# Build an initial, data-free model so the UI can summarize predicates and rules;
# it is rebuilt below once the ground-atom data (ATOMS) has been defined.
model = build_model()


def summarize_model(model: Model):
    """Return Markdown summary of predicates and rules."""
    preds = model.get_predicates()
    lines = [
        "### PSL Model Build Summary",
        f"- Model name: **{model._name}**",
        f"- Total predicates: {len(preds)}",
        f"- Total rules: {len(RULES)}",
        "",
        "#### Predicates:",
    ]
    for pred_name, p in preds.items():
        lines.append(f"- `{pred_name}` (arity = {len(p._types)})")
    lines.append("\n#### Rules:")
    for r in RULES:
        desc = r['comment']
        lines.append(f"- **{r['name']}** — {desc}")
    return "\n".join(lines)


MODEL_SUMMARY_MD = summarize_model(model)


# Render the rule-dependency graph once at startup for the UI tab below.
GRAPH_PATH = rules_to_graphviz_file(RULES, basename="rules_graph")


# Ground-atom data per predicate, split into PSL partitions:
#   OBS (observations), TARGETS (atoms to infer), TRUTH (held-out labels).
ATOMS = {
    "Minimal": {
        "OBS": pd.DataFrame({"C": ["C4", "C5"], "VALUE": [1.0, 0.0]}),
        "TARGETS": pd.DataFrame({"C": ["C1", "C2", "C3"], "VALUE": [0.5, 0.5, 0.5]}),
        "TRUTH": pd.DataFrame({"C": ["C3"], "VALUE": [0.2]}),
    },

    "Sufficient": {
        "OBS": pd.DataFrame({"C": ["C1", "C2", "C3"], "VALUE": [1.0, 1.0, 0.0]}),
        "TARGETS": pd.DataFrame({"C": ["C4"], "VALUE": [0.5]}),
        "TRUTH": pd.DataFrame({"C": ["C1", "C2", "C3", "C4"], "VALUE": [1.0, 1.0, 0.0, 1.0]}),
    },

    "RemMassHi": {
        "OBS": pd.DataFrame({"C": ["C2"], "VALUE": [1.0]}),
        "TARGETS": pd.DataFrame({"C": ["C1", "C3"], "VALUE": [0.5, 0.5]}),
    },

    "FromMultipleStartSeeds": {
        "OBS": pd.DataFrame({"C": ["C1", "C2", "C3"], "VALUE": [1.0, 1.0, 0.0]}),
    },

    "InCircuit": {
        "OBS": pd.DataFrame({"E": ["e1", "e2", "e3"], "C": ["C1", "C1", "C1"], "VALUE": [1.0, 1.0, 1.0]}),
        "TARGETS": pd.DataFrame({"E": ["e4", "e5"], "C": ["C2", "C2"], "VALUE": [0.5, 0.5]}),
    },

    "PerfDrop": {
        "OBS": pd.DataFrame({"E": ["e1", "e2", "e3", "e4"], "C": ["C1", "C1", "C1", "C2"], "VALUE": [0.40, 0.10, 0.50, 0.20]}),
        "TARGETS": pd.DataFrame({"E": ["e5", "e6"], "C": ["C2", "C2"], "VALUE": [0.5, 0.5]}),
    },

    "SafeToRemove": {
        "OBS": pd.DataFrame({"E": ["e2", "e4", "e5"], "C": ["C1", "C2", "C2"], "VALUE": [0.90, 0.80, 0.30]}),
        "TARGETS": pd.DataFrame({"E": ["e1", "e3"], "C": ["C1", "C1"], "VALUE": [0.5, 0.5]}),
    },

    "Removable": {
        "TARGETS": pd.DataFrame({
            "E": ["e1", "e2", "e3", "e5"],
            "C": ["C1", "C1", "C1", "C2"],
            "VALUE": [0.5, 0.5, 0.5, 0.5],
        }),
        "OBS": pd.DataFrame({"E": ["e4"], "C": ["C2"], "VALUE": [0.2]}),
    },

    "Critical": {
        "TARGETS": pd.DataFrame({
            "E": ["e1", "e2", "e3", "e4", "e5", "e6"],
            "C": ["C1", "C1", "C1", "C2", "C2", "C2"],
            "VALUE": [0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
        }),
    },

    "Dominates": {
        "OBS": pd.DataFrame({"C2": ["C2", "C3"], "C": ["C1", "C1"], "VALUE": [1.0, 0.0]}),
        "TARGETS": pd.DataFrame({"C2": ["C3"], "C": ["C2"], "VALUE": [0.5]}),
    },

    "RandCircuitBetter": {
        "OBS": pd.DataFrame({"C": ["C1", "C2"], "Cr": ["R1", "R1"], "VALUE": [1.0, 0.0]}),
        "TARGETS": pd.DataFrame({"C": ["C1"], "Cr": ["R2"], "VALUE": [0.5]}),
    },

    "Tie": {
        "OBS": pd.DataFrame({"E1": ["e1", "e2", "e3"], "E2": ["e2", "e3", "e4"], "C": ["C1", "C1", "C1"], "VALUE": [1.0, 0.0, 0.2]}),
        "TARGETS": pd.DataFrame({"E1": ["e5", "e6"], "E2": ["e6", "e7"], "C": ["C2", "C2"], "VALUE": [0.5, 0.5]}),
    },

    "TieCovered": {
        "OBS": pd.DataFrame({"E1": ["e1", "e2"], "E2": ["e2", "e3"], "C": ["C1", "C1"], "VALUE": [1.0, 1.0]}),
        "TARGETS": pd.DataFrame({"E1": ["e3", "e5", "e6"], "E2": ["e4", "e6", "e7"], "C": ["C1", "C2", "C2"], "VALUE": [0.5, 0.5, 0.5]}),
    },
}


_PARTITION_MAP = {
    "OBS": Partition.OBSERVATIONS,
    "TARGETS": Partition.TARGETS,
    "TRUTH": Partition.TRUTH,
}


def _rename_cols_for_psl(df: pd.DataFrame) -> pd.DataFrame:
    """
    Convert human-friendly columns to PSL's required integer columns:
    args -> 0..k-1 in order of appearance; the VALUE column becomes column k.
    """
    cols = list(df.columns)
    assert "VALUE" in cols, "Each frame must include a VALUE column."
    arg_cols = [c for c in cols if c != "VALUE"]
    new_cols = {col: i for i, col in enumerate(arg_cols)}
    new_cols["VALUE"] = len(arg_cols)
    return df.rename(columns=new_cols)[list(range(len(arg_cols) + 1))]
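
# For illustration, a frame with columns ["E", "C", "VALUE"] is renamed to the
# integer columns [0, 1, 2] that the loader below passes to Predicate.add_data:
#   _rename_cols_for_psl(pd.DataFrame({"E": ["e1"], "C": ["C1"], "VALUE": [1.0]}))
#   #     0   1    2
#   # 0  e1  C1  1.0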


def load_atoms_config(model, atoms=ATOMS):
    """Load the ATOMS config into the PSL model, one predicate at a time."""
    for pred_name, parts in atoms.items():
        pred = model.get_predicate(pred_name)
        for part_key, df in parts.items():
            partition = _PARTITION_MAP[part_key]
            pred.add_data(partition, _rename_cols_for_psl(df))
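
# A minimal sketch of extending the toy data, assuming a hypothetical extra
# edge "e9" observed in circuit "C1"; appending to ATOMS before calling
# load_atoms_config() is enough, since the loader walks the dict generically:
#
#   ATOMS["InCircuit"]["OBS"] = pd.concat([
#       ATOMS["InCircuit"]["OBS"],
#       pd.DataFrame({"E": ["e9"], "C": ["C1"], "VALUE": [1.0]}),
#   ], ignore_index=True)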


# Rebuild the model and load the ground-atom data defined above.
model = build_model()
load_atoms_config(model, ATOMS)


import numpy as np


def atoms_summary_table(atoms=ATOMS) -> pd.DataFrame:
    rows = []
    for pname, parts in atoms.items():
        obs = len(parts.get("OBS", pd.DataFrame()))
        tgt = len(parts.get("TARGETS", pd.DataFrame()))
        tru = len(parts.get("TRUTH", pd.DataFrame()))
        rows.append({"Predicate": pname, "Observations": obs, "Targets": tgt, "Truth": tru})
    df = pd.DataFrame(rows).sort_values("Predicate").reset_index(drop=True)
    return df


ATOMS_TABLE = atoms_summary_table(ATOMS)


ATOMS_COMMENTARY = (
    "### Facts/Atoms Overview\n"
    "- **Observations** are known inputs (fixed evidence).\n"
    "- **Targets** are latent atoms the model will **infer** (e.g., whether `C1`, `C2`, `C3` are `Minimal`).\n"
    "- **Truth** (if provided) is held-out gold used to **evaluate** performance (not used during inference).\n"
    "_Note: the same ground atom must not appear in both Observations and Targets._"
)


def get_rules_and_weights(model):
    """
    Return a dict mapping each rule's textual body to its weight (float).
    Relies on pslpython internals (_rules, _weight, _rule_body); unweighted
    hard constraints may show up with a weight of None.
    """
    rules = model._rules
    return {r._rule_body: r._weight for r in rules}


# Snapshot rule weights before weight learning.
start_rules_to_weights = get_rules_and_weights(model)

# Weight learning uses the TRUTH partitions loaded above (Minimal, Sufficient)
# to adjust soft-rule weights; inference then fills in values for TARGET atoms.
model.learn(psl_options=ADDITIONAL_PSL_OPTIONS)

results = model.infer(psl_options=ADDITIONAL_PSL_OPTIONS)


# model.infer() returns a dict keyed by Predicate; each value is a DataFrame
# whose columns are the argument positions (0..arity-1) plus a 'truth' column.
named_results = {}
for pred_obj, df in results.items():
    name = pred_obj.name() if hasattr(pred_obj, "name") and callable(pred_obj.name) else getattr(pred_obj, "name", str(pred_obj))
    named_results[str(name).upper()] = df

minimal_results = named_results.get("MINIMAL", pd.DataFrame())

if not minimal_results.empty:
    minimal_results_display = minimal_results[[0, "truth"]].sort_values(0).reset_index(drop=True)
    minimal_results_display = minimal_results_display.rename(columns={0: "Circuit_ID", "truth": "VALUE"})
else:
    minimal_results_display = pd.DataFrame(columns=["Circuit_ID", "VALUE"])


end_rules_to_weights = get_rules_and_weights(model)


def weights_comparison_df(start_w: dict, end_w: dict) -> pd.DataFrame:
    keys = sorted(set(start_w.keys()) | set(end_w.keys()))
    rows = []
    for k in keys:
        sw = start_w.get(k, float("nan"))
        ew = end_w.get(k, float("nan"))
        rows.append({"Rule": k, "StartWeight": sw, "EndWeight": ew, "Delta": (ew - sw) if (pd.notna(sw) and pd.notna(ew)) else float("nan")})
    df = pd.DataFrame(rows)
    return df.sort_values("Delta", key=lambda s: s.abs(), ascending=False).reset_index(drop=True)


WEIGHTS_TABLE = weights_comparison_df(start_rules_to_weights, end_rules_to_weights)


with gr.Blocks(title="PSL Minimal-Circuit • Static") as demo:
    gr.Markdown("## PSL Minimal-Circuit • Static Overview")

    with gr.Tabs():
        with gr.Tab("Rules & Model"):
            gr.Markdown(COMMENTARY_MD)
            gr.Markdown(MODEL_SUMMARY_MD)

        with gr.Tab("Dependency graph"):
            gr.Image(value=GRAPH_PATH, label="Rule Graph (green = positive head, red dashed = negated head)")

        with gr.Tab("Atoms"):
            gr.Markdown(ATOMS_COMMENTARY)

            # Detailed listing of every ground atom per predicate and partition.
            def _atoms_summary_table(atoms=ATOMS):
                def fmt_atoms(pname, df):
                    if df is None or len(df) == 0:
                        return ""
                    var_cols = [c for c in df.columns if c != "VALUE"]
                    lines = []
                    for _, r in df.iterrows():
                        args = ", ".join(str(r[c]) for c in var_cols)
                        val = r["VALUE"] if "VALUE" in df.columns else ""
                        lines.append(f"{pname}({args}) = {val}")
                    return "\n".join(lines)

                rows = []
                for pname, parts in atoms.items():
                    rows.append({
                        "Predicate": pname,
                        "Observations": fmt_atoms(pname, parts.get("OBS", pd.DataFrame())),
                        "Targets": fmt_atoms(pname, parts.get("TARGETS", pd.DataFrame())),
                        "Truth": fmt_atoms(pname, parts.get("TRUTH", pd.DataFrame())),
                    })
                return pd.DataFrame(rows)

            gr.Dataframe(
                value=_atoms_summary_table(ATOMS),
                label="Atoms by predicate and partition",
                interactive=False
            )
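
            # Also surface the module-level count summary computed above
            # (ATOMS_TABLE), so per-partition counts sit next to the full listing.
            gr.Dataframe(
                value=ATOMS_TABLE,
                label="Atoms Summary (counts)",
                interactive=False
            )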

        with gr.Tab("Learning & Inference"):
            gr.Markdown("### Training & Inference Results")
            gr.Markdown(
                "Below are the **inferred minimality scores** for each circuit `C`, "
                "and a comparison of **rule weights** before and after `model.learn()`."
            )
            gr.Markdown("#### Inferred Minimality: `Minimal(C)`")
            gr.Dataframe(value=minimal_results_display, interactive=False)

            gr.Markdown("#### Rule Weights (Before vs After Learning)")
            gr.Dataframe(value=WEIGHTS_TABLE, interactive=False)

            gr.Markdown(
                "_Notes:_ Learning adjusts **soft rule** weights to better explain the provided TRUTH. "
                "Hard constraints remain fixed. Inference then computes truth values in `[0,1]` for target atoms."
            )


if __name__ == "__main__":
    demo.launch()