-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsemantics.py
More file actions
158 lines (132 loc) · 5.17 KB
/
semantics.py
File metadata and controls
158 lines (132 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import annotations

import re
from dataclasses import dataclass
from enum import Enum

from memory_engine.schema import MemoryEdge, MemoryNode
class SemanticRole(str, Enum):
    """Closed set of roles that a memory node's text can play.

    The ``str`` mixin makes every member compare equal to its lowercase
    value, so a member and its ``.value`` string are interchangeable in
    comparisons. The helpers in this module compare
    ``node.attributes["semantic_role"]`` against these members' ``.value``.
    """

    # Fallback role for non-"step" nodes; also cued by "shall"/"must".
    OBLIGATION = "obligation"
    # Cued by "if", "when", "after", "once", "subject to".
    CONDITION = "condition"
    # Cued by "unless", "except", "notwithstanding".
    EXCEPTION = "exception"
    # Cued by "terminate", "damages", "withhold", "recover".
    REMEDY = "remedy"
    # Cued by "escalate", "page".
    ESCALATION = "escalation"
    # Fallback role for "step" nodes; also cued by "restart", "verify", etc.
    ACTION = "action"
@dataclass(frozen=True, slots=True)
class ExceptionLink:
source_node_id: str
target_node_id: str
edge_type: str = "exception_to"
explanation: str = ""
@dataclass(frozen=True, slots=True)
class ContradictionCandidate:
left_node_id: str
right_node_id: str
explanation: str
@dataclass(frozen=True, slots=True)
class SemanticScoreSignals:
exception_score: float = 0.0
contradiction_score: float = 0.0
def _has_keyword(text: str, keywords: tuple[str, ...]) -> bool:
    """Return True if any of *keywords* starts at a word boundary in *text*.

    Only the start of each keyword is anchored. Inflected forms such as
    "escalated" or "terminates" therefore still match, while fragments buried
    inside unrelated words ("if" in "specific", "once" in "concerned",
    "after" in "thereafter") do not.
    """
    return any(re.search(r"\b" + re.escape(keyword), text) for keyword in keywords)


def infer_semantic_role(text: str, *, node_type: str) -> SemanticRole:
    """Infer the semantic role of a node from keyword cues in its text.

    Cue groups are tested in priority order and the first hit wins:
    exception, remedy, escalation, action, condition, then obligation. Each
    cue must begin at a word boundary (see ``_has_keyword``). Matching is
    case-insensitive because *text* is lowercased first.

    Args:
        text: The node's text.
        node_type: The node's type. A "step" node with no cue defaults to
            ACTION; any other node with no cue defaults to OBLIGATION.

    Returns:
        The inferred SemanticRole.
    """
    lowered = text.lower()
    if _has_keyword(lowered, ("unless", "except", "notwithstanding")):
        return SemanticRole.EXCEPTION
    if _has_keyword(lowered, ("terminate", "damages", "withhold", "recover")):
        return SemanticRole.REMEDY
    if _has_keyword(lowered, ("escalate", "page")):
        return SemanticRole.ESCALATION
    if _has_keyword(lowered, ("restart", "roll back", "verify", "notify")):
        return SemanticRole.ACTION
    if _has_keyword(lowered, ("if", "when", "after", "once", "subject to")):
        return SemanticRole.CONDITION
    # Modal verbs only change the outcome for "step" nodes; every other node
    # type already falls back to OBLIGATION below.
    if node_type in {"clause", "step"} and _has_keyword(lowered, ("shall", "must")):
        return SemanticRole.OBLIGATION
    return SemanticRole.ACTION if node_type == "step" else SemanticRole.OBLIGATION
def semantic_activation_bonus(node: MemoryNode) -> float:
    """Return a fixed bonus keyed on the node's ``semantic_role`` attribute.

    Exception nodes get 0.15. Remedy and escalation nodes get 0.08. Any
    other role, or a missing attribute, gets 0.0.

    NOTE(review): presumably added to the node's activation during
    retrieval; confirm with the caller.
    """
    role = node.attributes.get("semantic_role")
    if role == SemanticRole.EXCEPTION.value:
        bonus = 0.15
    elif role in {SemanticRole.REMEDY.value, SemanticRole.ESCALATION.value}:
        bonus = 0.08
    else:
        bonus = 0.0
    return bonus
def contradiction_candidates(nodes: list[MemoryNode], edges: list[MemoryEdge]) -> list[ContradictionCandidate]:
    """Flag rule/exception pairs joined by an ``exception_to`` edge.

    An edge yields a candidate only when all of these hold:

    * its type is ``exception_to``;
    * both of its endpoints are present in *nodes*;
    * its source node's ``attributes["semantic_role"]`` equals the exception
      role.

    Other edges are skipped silently. No de-duplication is performed, so
    repeated edges produce repeated candidates.

    Args:
        nodes: Nodes available for lookup by ``id``. A later node with a
            duplicate id shadows an earlier one.
        edges: Edges to scan, in order.

    Returns:
        One ContradictionCandidate per qualifying edge, in edge order. Each
        has ``left_node_id`` set to the edge target and ``right_node_id`` set
        to the edge source.
    """
    lookup = {node.id: node for node in nodes}
    exception_role = SemanticRole.EXCEPTION.value
    flagged: list[ContradictionCandidate] = []
    for link in (e for e in edges if e.edge_type == "exception_to"):
        exception_node = lookup.get(link.from_id)
        rule_node = lookup.get(link.to_id)
        if exception_node is None or rule_node is None:
            # Dangling edge: one endpoint is not among the supplied nodes.
            continue
        if exception_node.attributes.get("semantic_role") == exception_role:
            flagged.append(
                ContradictionCandidate(
                    left_node_id=rule_node.id,
                    right_node_id=exception_node.id,
                    explanation="exception link may override the general rule",
                )
            )
    return flagged
def semantic_score_signals(
    node: MemoryNode,
    *,
    source_node_id: str | None = None,
) -> SemanticScoreSignals:
    """Compute the exception and contradiction scores for *node*.

    exception_score:
        1.0 for exception nodes, 0.45 for remedy or escalation nodes, and
        0.0 otherwise.

    contradiction_score:
        Based on ``attributes["contradiction_targets"]`` (absent means
        empty). It is 0.0 when there are no targets, 0.45 when there are
        some, and 1.0 when *source_node_id* is one of them.

    Args:
        node: Node whose ``attributes`` are inspected.
        source_node_id: Optional id of the node the query started from.

    Returns:
        A SemanticScoreSignals with both scores filled in.
    """
    attributes = node.attributes
    role = attributes.get("semantic_role")
    if role == SemanticRole.EXCEPTION.value:
        exception_score = 1.0
    elif role in {SemanticRole.REMEDY.value, SemanticRole.ESCALATION.value}:
        exception_score = 0.45
    else:
        exception_score = 0.0

    targets = set(attributes.get("contradiction_targets", []))
    contradiction_score = 0.45 if targets else 0.0
    # A direct hit on the query's source node outranks merely having targets.
    if source_node_id is not None and source_node_id in targets:
        contradiction_score = 1.0

    return SemanticScoreSignals(
        exception_score=exception_score,
        contradiction_score=contradiction_score,
    )
def query_role_alignment_score(query: str, node: MemoryNode) -> float:
    """Score whether *query* is asking for what an escalation node provides.

    Returns 1.0 when the node's ``semantic_role`` is escalation and the
    lowercased query contains one of the escalation cues. Returns 0.0
    otherwise; no other role currently earns alignment credit.
    """
    normalized_query = query.lower()
    role = node.attributes.get("semantic_role")
    # NOTE(review): cues are plain substring checks, so "page" also matches
    # inside words such as "homepage".
    escalation_cues = ("escalate", "escalation", "page", "who should")
    is_escalation_node = role == SemanticRole.ESCALATION.value
    if is_escalation_node and any(cue in normalized_query for cue in escalation_cues):
        return 1.0
    return 0.0
def contradiction_bonus(
    *,
    node_id: str,
    candidates: list[ContradictionCandidate],
    source_node_id: str | None = None,
) -> float:
    """Return a small ranking bonus for a node involved in contradictions.

    Every candidate that mentions *node_id* contributes its other endpoint to
    a set of counterparts. Scoring is delegated to ``semantic_score_signals``
    through a throwaway probe node, rather than re-implementing its
    thresholds here.

    Args:
        node_id: Id of the node being scored.
        candidates: Contradiction pairs to search for *node_id*.
        source_node_id: Optional id of the node the query started from.

    Returns:
        0.14 when *source_node_id* is one of the counterparts, 0.06 when at
        least one counterpart exists, and 0.0 otherwise.
    """
    counterparts: set[str] = set()
    for candidate in candidates:
        endpoints = {candidate.left_node_id, candidate.right_node_id}
        if node_id in endpoints:
            counterparts |= endpoints - {node_id}

    # Synthetic node whose only job is to carry the counterpart ids into
    # semantic_score_signals.
    probe = MemoryNode(
        id=node_id,
        type="semantic_probe",
        content="",
        attributes={"contradiction_targets": sorted(counterparts)},
    )
    signals = semantic_score_signals(probe, source_node_id=source_node_id)
    strength = signals.contradiction_score
    if strength >= 1.0:
        return 0.14
    if strength > 0:
        return 0.06
    return 0.0
def surfaced_contradictions(
    returned_node_ids: list[str] | set[str],
    candidates: list[ContradictionCandidate],
) -> list[tuple[str, str]]:
    """List the contradiction pairs whose two sides were both returned.

    Args:
        returned_node_ids: Ids of the nodes in a result set. Order and
            duplicates do not matter.
        candidates: Contradiction candidates to check.

    Returns:
        ``(left_node_id, right_node_id)`` for every candidate whose two ids
        both appear in *returned_node_ids*, in candidate order.
    """
    visible = frozenset(returned_node_ids)
    return [
        (candidate.left_node_id, candidate.right_node_id)
        for candidate in candidates
        if candidate.left_node_id in visible and candidate.right_node_id in visible
    ]