Source code for aurora_swarm.patterns.tree_reduce

"""Pattern 3 — Hierarchical Tree-Reduce.

Leaf agents produce initial responses.  Groups of responses are fed to
supervisor agents that summarize them, recursively, until a single
answer remains.
"""

from __future__ import annotations

from typing import Any

from aurora_swarm.pool import AgentPool, Response


def _has_content(text: str) -> bool:
    """Return True if response text has non-whitespace content."""
    return bool((text or "").strip())


async def tree_reduce(
    pool: AgentPool,
    prompt: str,
    reduce_prompt: str,
    fanin: int = 50,
    items: list[Any] | None = None,
) -> Response:
    """Run a hierarchical tree-reduce over *pool*.

    Parameters
    ----------
    pool:
        The agent pool (used for both leaf work and supervisors).
    prompt:
        Leaf-level task.  If *items* is provided the template should
        contain an ``{item}`` placeholder.
    reduce_prompt:
        Supervisor summarisation task.  Must contain ``{responses}``
        and may contain ``{level}``.
    fanin:
        Number of responses each supervisor handles per group.  Must be
        at least 2 so that every reduction level shrinks the response
        list.
    items:
        If given, one leaf prompt is built per item (``{item}`` replaced
        by ``str(item)``) and the prompts are sent with
        ``pool.send_all``.  Otherwise the same *prompt* is broadcast.

    Returns
    -------
    Response
        A successful response carrying the single remaining text, or a
        failed response when no successful, non-blank text survives the
        leaf phase or any reduction level.

    Raises
    ------
    ValueError
        If *fanin* is smaller than 2.
    """
    # A fan-in of 1 never reduces the number of responses (endless loop),
    # 0 is an invalid range step, and a negative value yields no groups
    # at all.  Reject these before any agent work is spent.
    if fanin < 2:
        raise ValueError(f"fanin must be at least 2, got {fanin}")

    # -- leaf phase ----------------------------------------------------------
    # Use send_all (chat completions per prompt) rather than
    # send_all_batched (raw completions) so chat models get the
    # expected message format.
    if items is not None:
        leaf_prompts = [prompt.replace("{item}", str(it)) for it in items]
        leaf_responses = await pool.send_all(leaf_prompts)
    else:
        leaf_responses = await pool.broadcast_prompt(prompt)

    # -- reduction phase -----------------------------------------------------
    # Only successful, non-blank responses take part in the reduction.
    current: list[str] = [
        r.text for r in leaf_responses if r.success and _has_content(r.text)
    ]

    level = 1
    while len(current) > 1:
        # Fill {level} on the template *before* inserting the responses,
        # so a literal "{level}" inside agent output is left untouched.
        template = reduce_prompt.replace("{level}", str(level))
        supervisor_prompts: list[str] = [
            template.replace(
                "{responses}", "\n---\n".join(current[start : start + fanin])
            )
            for start in range(0, len(current), fanin)
        ]

        sup_responses = await pool.send_all(supervisor_prompts)
        current = [
            r.text for r in sup_responses if r.success and _has_content(r.text)
        ]
        level += 1

    if not current:
        return Response(
            success=False, text="", error="All agents failed during reduction"
        )
    return Response(success=True, text=current[0])