Source code for aurora_swarm.patterns.pipeline
"""Pattern 5 — Pipeline (Multi-Stage DAG).
Defines a sequence of stages, each served by a pool of agents. The
output of one stage flows as input to the next.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Callable
from aurora_swarm.pool import AgentPool, Response
[docs]
@dataclass
class Stage:
    """Configuration for a single pipeline step.

    A stage pairs a prompt template with the number of agents that
    should answer it, plus optional hooks that post-process the
    answers before they are handed to the following stage.

    Attributes
    ----------
    name:
        Label identifying the stage (for humans; not interpreted here).
    prompt_template:
        Prompt text. Every ``{input}`` occurrence is substituted with
        the preceding stage's output, or with the pipeline's initial
        input when this is the first stage.
    n_agents:
        Number of agents requested for this stage.
    output_transform:
        ``f(responses) -> Any`` — collapses the stage's responses into
        the single value passed on to the next stage. When ``None``,
        the response texts are joined with newlines.
    output_filter:
        ``f(response) -> bool`` — predicate applied before the
        transform; responses for which it returns ``False`` are
        discarded.
    """

    # Required settings.
    name: str
    prompt_template: str
    n_agents: int

    # Optional post-processing hooks; ``None`` means "not configured".
    output_transform: Callable[[list[Response]], Any] | None = None
    output_filter: Callable[[Response], bool] | None = None
def _default_transform(responses: list[Response]) -> str:
    """Join the texts of all successful responses, one per line.

    Responses whose ``success`` flag is falsy are skipped; an empty
    string is returned when nothing succeeded.
    """
    successful_texts = [resp.text for resp in responses if resp.success]
    return "\n".join(successful_texts)
[docs]
async def run_pipeline(
    pool: AgentPool,
    stages: list[Stage],
    initial_input: Any,
    reuse_agents: bool = True,
) -> Any:
    """Execute stages sequentially; the output of each flows to the next.

    For every stage the same prompt is broadcast to all agents selected
    for that stage, the responses are optionally filtered, and then
    reduced to a single value by the stage's transform.

    Parameters
    ----------
    pool:
        The full agent pool.
    stages:
        Ordered list of pipeline stages.
    initial_input:
        Value substituted into ``{input}`` for the first stage.
    reuse_agents:
        If ``True`` every stage uses the first ``n_agents`` agents of
        the pool. If ``False`` the pool is partitioned so each stage
        receives a dedicated, non-overlapping slice, assigned in stage
        order. In both modes the request is silently clamped to the
        pool size, so a stage may receive fewer agents than requested
        (possibly none once a partitioned pool is exhausted).

    Returns
    -------
    Any
        The transformed output of the final stage, or ``initial_input``
        unchanged when ``stages`` is empty.
    """
    current_input = initial_input
    offset = 0  # start index of the next unassigned partition

    for stage in stages:
        # Select agents for this stage.
        if reuse_agents:
            agent_count = min(stage.n_agents, pool.size)
            stage_pool = pool.select(list(range(agent_count)))
        else:
            end = min(offset + stage.n_agents, pool.size)
            stage_pool = pool.slice(offset, end)
            offset = end

        # Every agent in the stage receives the identical prompt.
        prompt = stage.prompt_template.replace("{input}", str(current_input))
        responses = await stage_pool.broadcast_prompt(prompt)

        if stage.output_filter is not None:
            responses = [r for r in responses if stage.output_filter(r)]

        # Compare against None explicitly (as for the filter above): a
        # user-supplied callable may legitimately evaluate as falsy, e.g. a
        # callable object defining ``__len__``, and must not be replaced by
        # the default transform.
        transform = stage.output_transform
        if transform is None:
            transform = _default_transform
        current_input = transform(responses)

    return current_input
async def fan_out_fan_in(
    pool: AgentPool,
    prompt: str,
    collect_prompt: str,
    n_workers: int | None = None,
) -> Response:
    """Broadcast a prompt to workers, then have one agent merge the answers.

    Parameters
    ----------
    pool:
        Agent pool supplying both the workers and the collector.
    prompt:
        Prompt broadcast to every worker.
    collect_prompt:
        Collector template; each ``{responses}`` occurrence is replaced
        by the successful worker texts joined with ``---`` separator
        lines.
    n_workers:
        Number of workers to use, capped at the pool size. ``None``
        (the default) broadcasts to the whole pool.

    Returns
    -------
    Response
        Whatever ``pool.post`` returns for the collector request, which
        is sent to agent index 0 of ``pool``.
    """
    # Fan out: default to the whole pool, narrow only when a count is given.
    workers = pool
    if n_workers is not None:
        worker_count = min(n_workers, pool.size)
        workers = pool.select(list(range(worker_count)))

    worker_replies = await workers.broadcast_prompt(prompt)

    # Fan in: only successful replies reach the collector prompt.
    successful_texts = [reply.text for reply in worker_replies if reply.success]
    collector_prompt = collect_prompt.replace(
        "{responses}", "\n---\n".join(successful_texts)
    )

    # The collector uses the pool's aggregation token limit when the pool
    # defines one (collector prompts are larger); otherwise None is passed.
    aggregation_limit = getattr(pool, "_max_tokens_aggregation", None)
    return await pool.post(0, collector_prompt, max_tokens=aggregation_limit)