Source code for aurora_swarm.pool

"""AgentPool — async connection pool for Aurora agent endpoints.

Provides semaphore-throttled, pooled HTTP access to 1000–4000 LLM agent
instances.  Every public pattern function in this package takes an
``AgentPool`` as its first argument.
"""

from __future__ import annotations

import asyncio
import random
from dataclasses import dataclass
from typing import Sequence

import aiohttp

from aurora_swarm.hostfile import AgentEndpoint


# ---------------------------------------------------------------------------
# Response dataclass
# ---------------------------------------------------------------------------

[docs] @dataclass class Response: """Result of a single agent call.""" success: bool text: str error: str | None = None agent_index: int = -1
# --------------------------------------------------------------------------- # AgentPool # ---------------------------------------------------------------------------
[docs] class AgentPool: """Async pool of agent HTTP endpoints with concurrency control. Parameters ---------- endpoints: Agent endpoints — either :class:`AgentEndpoint` objects or plain ``(host, port)`` tuples (tags will be empty). concurrency: Maximum number of in-flight requests (asyncio semaphore size). connector_limit: Maximum number of TCP connections in the aiohttp pool. timeout: Per-request timeout in seconds. """ def __init__( self, endpoints: Sequence[AgentEndpoint | tuple[str, int]], concurrency: int = 512, connector_limit: int = 1024, timeout: float = 120.0, ) -> None: self._endpoints: list[AgentEndpoint] = [] for ep in endpoints: if isinstance(ep, AgentEndpoint): self._endpoints.append(ep) else: host, port = ep self._endpoints.append(AgentEndpoint(host=host, port=port)) self._concurrency = concurrency self._connector_limit = connector_limit self._timeout = timeout self._semaphore = asyncio.Semaphore(concurrency) self._session: aiohttp.ClientSession | None = None # -- lifecycle ----------------------------------------------------------- async def _get_session(self) -> aiohttp.ClientSession: if self._session is None or self._session.closed: connector = aiohttp.TCPConnector(limit=self._connector_limit) self._session = aiohttp.ClientSession(connector=connector) return self._session
[docs] async def close(self) -> None: if self._session and not self._session.closed: await self._session.close()
async def __aenter__(self) -> "AgentPool": await self._get_session() # Eagerly create session so parent owns it; sub-pools share it and close() on exit closes it return self async def __aexit__(self, *exc: object) -> None: await self.close() # -- properties ---------------------------------------------------------- @property def size(self) -> int: """Number of agents in the pool.""" return len(self._endpoints) @property def timeout(self) -> float: """Base per-request timeout in seconds.""" return self._timeout @property def endpoints(self) -> list[AgentEndpoint]: return list(self._endpoints) # -- selectors -----------------------------------------------------------
[docs] def by_tag(self, key: str, value: str) -> "AgentPool": """Return a sub-pool of agents whose tag *key* equals *value*.""" filtered = [ep for ep in self._endpoints if ep.tags.get(key) == value] return self._sub_pool(filtered)
[docs] def sample(self, n: int) -> "AgentPool": """Return a sub-pool of *n* randomly chosen agents.""" chosen = random.sample(self._endpoints, min(n, len(self._endpoints))) return self._sub_pool(chosen)
[docs] def select(self, indices: Sequence[int]) -> "AgentPool": """Return a sub-pool with agents at the given indices.""" selected = [self._endpoints[i] for i in indices] return self._sub_pool(selected)
[docs] def slice(self, start: int, stop: int) -> "AgentPool": """Return a sub-pool from index *start* to *stop*.""" return self._sub_pool(self._endpoints[start:stop])
def _sub_pool(self, endpoints: list[AgentEndpoint]) -> "AgentPool": """Create a child pool sharing concurrency settings.""" child = AgentPool.__new__(AgentPool) child._endpoints = endpoints child._concurrency = self._concurrency child._connector_limit = self._connector_limit child._timeout = self._timeout child._semaphore = self._semaphore # share parent semaphore child._session = self._session # share parent session return child # -- core request --------------------------------------------------------
[docs] async def post(self, agent_index: int, prompt: str, max_tokens: int | None = None) -> Response: """Send *prompt* to the agent at *agent_index* and return its response. The call is throttled by the pool-wide semaphore so that at most ``concurrency`` requests are in flight at once. Parameters ---------- agent_index: Index of the agent to send the prompt to. prompt: The prompt text. max_tokens: Optional maximum tokens to generate. Ignored by base AgentPool (only used by VLLMPool and subclasses). """ ep = self._endpoints[agent_index] session = await self._get_session() async with self._semaphore: try: async with session.post( f"{ep.url}/generate", json={"prompt": prompt}, timeout=aiohttp.ClientTimeout(total=self._timeout), ) as resp: data = await resp.json() return Response( success=True, text=data.get("response", data.get("text", "")), agent_index=agent_index, ) except Exception as exc: return Response( success=False, text="", error=str(exc), agent_index=agent_index, )
[docs] async def send_all(self, prompts: list[str]) -> list[Response]: """Send ``prompts[i]`` to ``agent[i % size]`` concurrently. Returns responses in *input* order (i.e. ``results[i]`` corresponds to ``prompts[i]``). """ tasks = [ self.post(i % self.size, prompt) for i, prompt in enumerate(prompts) ] return list(await asyncio.gather(*tasks))
[docs] async def send_all_batched(self, prompts: list[str], max_tokens: int | None = None) -> list[Response]: """Send prompts with batching if supported, otherwise use send_all. Default implementation for base AgentPool — just delegates to send_all. VLLMPool overrides this to use batch API for efficiency. Parameters ---------- prompts: List of prompts to send. max_tokens: Optional max tokens override (ignored in base implementation). Returns ------- list[Response] Responses in the same order as input prompts. """ return await self.send_all(prompts)
[docs] async def broadcast_prompt(self, prompt: str) -> list[Response]: """Send the same *prompt* to every agent in the pool.""" tasks = [self.post(i, prompt) for i in range(self.size)] return list(await asyncio.gather(*tasks))