Class mapping for an agentic ReAct (reason + act) framework:
class ReActController:
    """Overview stub for the loop entry point that answers one query.

    The skeleton version further down makes ``context`` optional (default
    ``None``) and returns the dict produced by ``FinalAnswerWriter.write``.
    """

    def run(self, query: str, context: list[dict]) -> dict: ...
class ReActState:
    """Overview stub of the per-query state carried through the ReAct loop.

    The skeleton version further down replaces the parallel ``thoughts`` /
    ``actions`` / ``observations`` lists with a single list of ``ReActStep``
    records and adds a ``max_steps`` budget.
    """

    # NOTE(review): RetrievedEvidence is only defined later in the skeleton. If this
    # snippet is run on its own, an interpreter that evaluates annotations eagerly
    # (Python < 3.14) raises NameError here unless `from __future__ import annotations`
    # is in effect or the class is defined first.
    query: str
    thoughts: list[str]
    actions: list[str]
    observations: list[dict]
    retrieved_evidence: list[RetrievedEvidence]
    done: bool
class ToolRouter:
    """Overview stub: dispatch a named action to a tool and return its observation dict.

    The skeleton version further down handles "retrieve", "compare_neighbors"
    and "inspect_frame", and raises ValueError for any other action name.
    """

    def call(self, action_name: str, params: dict) -> dict: ...
class FinalAnswerWriter:
    """Overview stub: turn the finished loop state into the answer payload.

    In the skeleton version further down, ``write`` also receives the
    ``QueryPlan`` (it reads ``plan.original_query`` for the payload).
    """

    def write(self, state: ReActState) -> dict: ...
This mapping keeps the ReAct reasoning loop separate from ingestion, indexing, and vision enrichment, which makes the agent easier to audit and lets each stage be tested in isolation.
Skeleton code applying this mapping to the ReAct framework:
from __future__ import annotations

import re
from dataclasses import dataclass, field
from typing import Any, Protocol, Literal
@dataclass
class FrameEvidence:
    """Evidence about a single video frame, as consumed by the ReAct loop.

    The identity fields (ids, timestamp, tour position, image location) are
    required. Every descriptive field has an empty default, so a frame can be
    constructed from its identity fields alone.
    """

    # --- identity (required) ---
    video_id: str
    tour_id: str
    frame_id: str  # passed to inspect_frame / compare_neighbors and cited in answers
    timestamp_ms: int  # cited next to frame_id; presumably milliseconds from video start (per the name) — confirm
    tour_order: int  # presumably the frame's position within its tour — confirm
    image_uri: str
    # --- descriptive fields (optional) ---
    caption: str = ""
    objects: list[str] = field(default_factory=list)  # default_factory gives each instance its own list
    ocr_text: str = ""
    spatial_relations: list[str] = field(default_factory=list)
    scene_type: str = ""
    change_note: str = ""
    confidence: float = 0.0  # scale is not specified here (e.g. 0..1?) — TODO confirm with the producer
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class RetrievedEvidence:
    """A FrameEvidence hit returned by a Retriever, together with its score."""

    evidence: FrameEvidence
    score: float  # copied into answer citations; whether higher means better is not specified here — confirm
    matched_query: str = ""  # presumably the subquery that produced this hit — confirm
    rank_reason: str = ""  # presumably a free-text explanation of the ranking, when the retriever supplies one
@dataclass
class ReActStep:
    """One thought -> action -> observation iteration of the ReAct loop."""

    thought: str = ""  # rationale text produced by ThoughtGenerator
    action_name: str = ""  # in this skeleton: "retrieve", "inspect_frame", "compare_neighbors" or "finalize"
    action_args: dict[str, Any] = field(default_factory=dict)  # params forwarded to ToolRouter.call
    observation: dict[str, Any] = field(default_factory=dict)  # ToolRouter result; stays empty for "finalize"
@dataclass
class ReActState:
    """Mutable state accumulated while ReActController answers one query."""

    query: str
    steps: list[ReActStep] = field(default_factory=list)  # executed steps, in execution order
    retrieved: list[RetrievedEvidence] = field(default_factory=list)  # all "retrieve" results, in arrival order
    max_steps: int = 3  # loop budget: ReActController.run iterates range(max_steps)
    done: bool = False  # set when the judge is satisfied or the thinker chooses "finalize"
@dataclass
class QueryPlan:
    """Retrieval plan produced by QueryPlanner for a single user query."""

    original_query: str
    subqueries: list[str]  # QueryPlanner fills this in lockstep with `intents` (same index, same intent)
    intents: list[str]  # as emitted by QueryPlanner: "object", "spatial", "temporal" or "general"
    needs_temporal: bool = False
    needs_spatial: bool = False
    needs_object: bool = True  # object lookup is the fallback intent when nothing else is detected
class Retriever(Protocol):
    """Structural interface for evidence search backends.

    Any object with a matching ``retrieve`` method satisfies this protocol;
    no inheritance is required. ToolRouter calls it with the keyword
    arguments ``subquery`` and ``top_k``.
    """

    def retrieve(self, subquery: str, top_k: int = 5) -> list[RetrievedEvidence]: ...
class ToolRouter:
    """Translate ReAct action names into concrete tool invocations.

    Every handler returns the observation dict that the controller records
    on the corresponding ReActStep.
    """

    def __init__(self, retriever: Retriever, vision_tools: Any = None) -> None:
        self.retriever = retriever
        # Kept for real vision handlers; the placeholder handlers below do not use it yet.
        self.vision_tools = vision_tools

    def call(self, action_name: str, params: dict[str, Any]) -> dict[str, Any]:
        """Run the handler registered for *action_name* and return its observation.

        Raises:
            ValueError: if no handler matches *action_name*.
            KeyError: if a required parameter ("subquery" / "frame_id") is missing.
        """
        # Ordered (name, handler) pairs: matched by equality, first hit wins.
        table = (
            ("retrieve", self._retrieve),
            ("compare_neighbors", self._compare_neighbors),
            ("inspect_frame", self._inspect_frame),
        )
        for name, handler in table:
            if action_name == name:
                return handler(params)
        raise ValueError(f"Unknown action: {action_name}")

    def _retrieve(self, params: dict[str, Any]) -> dict[str, Any]:
        """Search the retriever with the given subquery (top_k defaults to 5)."""
        text = params["subquery"]
        limit = params.get("top_k", 5)
        hits = self.retriever.retrieve(subquery=text, top_k=limit)
        return {"results": hits}

    def _compare_neighbors(self, params: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: describe a comparison around the given frame."""
        target = params["frame_id"]
        return {"comparison": f"neighbor comparison for {target}"}

    def _inspect_frame(self, params: dict[str, Any]) -> dict[str, Any]:
        """Placeholder: describe a vision inspection of the given frame."""
        target = params["frame_id"]
        return {"inspection": f"vision inspection for {target}"}
class QueryPlanner:
    """Heuristic planner: detect retrieval intents from keywords and build subqueries.

    Keywords are matched at the start of a word (regex ``\\b`` prefix). This
    keeps inflections such as "changed" or "nearby" while rejecting
    mid-word hits like "bright" ("right"), "linear" ("near") or
    "represent" ("present"), which plain substring tests let through.
    Prefix matching can still over-match (e.g. "afternoon" -> "after").
    """

    TEMPORAL_KEYWORDS: tuple[str, ...] = ("change", "before", "after", "timeline", "evolve", "progress")
    SPATIAL_KEYWORDS: tuple[str, ...] = ("left", "right", "near", "behind", "in front", "relative")
    OBJECT_KEYWORDS: tuple[str, ...] = ("object", "see", "show", "visible", "present")

    @staticmethod
    def _mentions(text: str, keywords: tuple[str, ...]) -> bool:
        """Return True if any keyword starts a word in *text* (expects lowercase text)."""
        pattern = r"\b(?:" + "|".join(re.escape(k) for k in keywords) + ")"
        return re.search(pattern, text) is not None

    def plan(self, query: str, context: list[dict[str, Any]] | None = None) -> QueryPlan:
        """Build a QueryPlan for *query*.

        Args:
            query: The user question.
            context: Accepted for interface compatibility; not used by this heuristic.

        Returns:
            A QueryPlan whose ``subqueries`` and ``intents`` are index-aligned.
            At least one entry is always present, because object lookup is the
            fallback when no temporal or spatial cue is found.
        """
        q = query.lower()
        needs_temporal = self._mentions(q, self.TEMPORAL_KEYWORDS)
        needs_spatial = self._mentions(q, self.SPATIAL_KEYWORDS)
        # Object lookup is both an explicit intent and the default when nothing else matched.
        needs_object = self._mentions(q, self.OBJECT_KEYWORDS) or not (needs_temporal or needs_spatial)

        subqueries: list[str] = []
        intents: list[str] = []
        if needs_object:
            subqueries.append(f"{query} objects scene caption")
            intents.append("object")
        if needs_spatial:
            subqueries.append(f"{query} spatial relations layout")
            intents.append("spatial")
        if needs_temporal:
            subqueries.append(f"{query} timeline frame progression change")
            intents.append("temporal")
        if not subqueries:
            # Defensive only: the needs_object fallback above already guarantees an entry.
            subqueries = [query]
            intents = ["general"]
        return QueryPlan(
            original_query=query,
            subqueries=subqueries,
            intents=intents,
            needs_temporal=needs_temporal,
            needs_spatial=needs_spatial,
            needs_object=needs_object,
        )
class ThoughtGenerator:
    """Rule-based policy that picks the next ReAct (thought, action, args) triple.

    Returned action names are the ones ToolRouter dispatches ("retrieve",
    "inspect_frame", "compare_neighbors"), or "finalize", which tells
    ReActController to stop the loop.
    """

    def next_thought(self, state: ReActState, plan: QueryPlan) -> tuple[str, str, dict[str, Any]]:
        """Choose the next step from the steps already taken.

        Priority:
          1. No steps yet: retrieve. Uses the temporal subquery when the plan
             needs temporal evidence, otherwise the first planned subquery.
          2. Spatial evidence needed, something retrieved, and no
             ``inspect_frame`` step yet: inspect the top retrieved frame.
          3. Temporal evidence needed, at least two hits, and no
             ``compare_neighbors`` step yet: compare around the top frame.
          4. Otherwise: finalize.

        Completed follow-ups are detected by ``action_name``. ToolRouter's
        inspect_frame observation has no "spatial" key, so keying on
        observation contents re-issued the inspection on every remaining step.
        """
        if not state.steps:
            if plan.needs_temporal:
                return (
                    "This question needs temporal evidence, so I should retrieve progression across frames.",
                    "retrieve",
                    {"subquery": self._subquery_for(plan, "temporal"), "top_k": 5},
                )
            return (
                "I should retrieve the most relevant evidence for the question.",
                "retrieve",
                {"subquery": self._subquery_for(plan), "top_k": 5},
            )

        taken = {step.action_name for step in state.steps}

        # Only inspect when there is a real frame to inspect (never a placeholder "").
        if plan.needs_spatial and state.retrieved and "inspect_frame" not in taken:
            best = state.retrieved[0].evidence.frame_id
            return (
                "I have object evidence, but I still need spatial confirmation.",
                "inspect_frame",
                {"frame_id": best},
            )
        if plan.needs_temporal and len(state.retrieved) >= 2 and "compare_neighbors" not in taken:
            best = state.retrieved[0].evidence.frame_id
            return (
                "I should compare neighboring frames to confirm change over time.",
                "compare_neighbors",
                {"frame_id": best},
            )
        return ("I have enough evidence to answer.", "finalize", {})

    @staticmethod
    def _subquery_for(plan: QueryPlan, intent: str | None = None) -> str:
        """Return the subquery planned for *intent*, else the first subquery (or the raw query)."""
        if intent is not None:
            for subquery, kind in zip(plan.subqueries, plan.intents):
                if kind == intent:
                    return subquery
        return plan.subqueries[0] if plan.subqueries else plan.original_query
class SufficiencyJudge:
    """Decide whether the ReAct loop has gathered enough evidence to stop.

    ReActController asks the judge before every thought. The judge must
    therefore also demand the verification steps the plan calls for.
    Otherwise the loop stops right after the first retrieval, and the
    thinker's inspect_frame / compare_neighbors follow-ups never run.
    """

    def is_sufficient(self, state: ReActState, plan: QueryPlan) -> bool:
        """Return True when the evidence and verification steps satisfy *plan*.

        Requirements:
          * at least one retrieved hit;
          * temporal plans: at least two hits and a ``compare_neighbors`` step;
          * spatial plans: an ``inspect_frame`` step.
        """
        if not state.retrieved:
            return False
        if plan.needs_temporal and len(state.retrieved) < 2:
            return False
        taken = {step.action_name for step in state.steps}
        if plan.needs_spatial and "inspect_frame" not in taken:
            return False
        if plan.needs_temporal and "compare_neighbors" not in taken:
            return False
        return True
class FinalAnswerWriter:
    """Assemble the answer payload that ReActController.run returns."""

    def write(self, state: ReActState, plan: QueryPlan, *, max_citations: int = 5) -> dict[str, Any]:
        """Build the answer dict for *plan* from the evidence and steps in *state*.

        Args:
            state: Finished loop state (retrieved evidence and executed steps).
            plan: The query plan; its ``original_query`` is echoed back.
            max_citations: Upper bound on citations (default 5, as before).
                Values <= 0 yield no citations.

        Returns:
            A dict with these keys:
              * "query": the original query.
              * "answer": placeholder text.
              * "citations": keeps retrieval order (no re-sorting by score).
                A frame repeated with the same (video_id, frame_id) is cited
                only once, first occurrence wins, so duplicates cannot crowd
                out distinct evidence.
              * "steps": shallow copies of each step's fields, so editing the
                payload does not rewrite the recorded state. Nested observation
                values are still shared.
        """
        unique: dict[tuple[str, str], RetrievedEvidence] = {}
        for hit in state.retrieved:
            unique.setdefault((hit.evidence.video_id, hit.evidence.frame_id), hit)

        citations = [
            {"frame_id": r.evidence.frame_id, "timestamp_ms": r.evidence.timestamp_ms, "score": r.score}
            for r in list(unique.values())[: max(max_citations, 0)]
        ]
        return {
            "query": plan.original_query,
            "answer": "Grounded answer synthesized from retrieved frame evidence.",
            "citations": citations,
            "steps": [dict(vars(s)) for s in state.steps],
        }
class ReActController:
    """Run the ReAct loop: plan the query, alternate thought/action/observation, write the answer.

    All collaborators are injected through the constructor, so each stage can
    be swapped or faked independently.
    """

    def __init__(
        self,
        planner: QueryPlanner,
        router: ToolRouter,
        thinker: ThoughtGenerator,
        judge: SufficiencyJudge,
        writer: FinalAnswerWriter,
    ) -> None:
        self.planner = planner
        self.router = router
        self.thinker = thinker
        self.judge = judge
        self.writer = writer

    def run(
        self,
        query: str,
        context: list[dict[str, Any]] | None = None,
        *,
        max_steps: int = 3,
    ) -> dict[str, Any]:
        """Answer *query* and return the payload built by the writer.

        Args:
            query: The user question.
            context: Forwarded unchanged to ``planner.plan``.
            max_steps: Iteration budget (default 3, as before). Each iteration
                either stops (judge satisfied or "finalize") or executes one
                tool action.

        Returns:
            Whatever ``writer.write(state, plan)`` returns.
        """
        plan = self.planner.plan(query, context=context)
        state = ReActState(query=query, max_steps=max_steps)
        for _ in range(state.max_steps):
            if self.judge.is_sufficient(state, plan):
                state.done = True
                break
            thought, action_name, action_args = self.thinker.next_thought(state, plan)
            step = ReActStep(thought=thought, action_name=action_name, action_args=action_args)
            if action_name == "finalize":
                state.steps.append(step)
                state.done = True
                break
            observation = self.router.call(action_name, action_args)
            step.observation = observation
            state.steps.append(step)
            if action_name == "retrieve":
                state.retrieved.extend(observation.get("results", []))
        else:
            # Budget exhausted without an explicit stop. The last action may
            # still have completed the evidence, so record the judge's verdict.
            state.done = self.judge.is_sufficient(state, plan)
        return self.writer.write(state, plan)
No comments:
Post a Comment