"""VLLMPool — AgentPool subclass for vLLM OpenAI-compatible endpoints.
vLLM exposes an OpenAI-compatible chat completions API at
``/v1/chat/completions``. This pool overrides :meth:`post` to speak
that protocol instead of the simpler ``/generate`` endpoint used by
the base :class:`AgentPool`.
"""
from __future__ import annotations
import asyncio
import math
import os
import aiohttp
from openai import AsyncOpenAI
from aurora_swarm.hostfile import AgentEndpoint
from aurora_swarm.pool import AgentPool, Response


# Retry on connection/timeout errors (these are transient when many
# concurrent connections are open).
def _is_retryable_connection_error(exc: BaseException) -> bool:
    name = type(exc).__name__
    msg = str(exc).lower()
    return (
        name
        in (
            "APIConnectionError",
            "APITimeoutError",
            "ConnectionError",
            "TimeoutError",
        )
        or "connection" in msg
        or "timeout" in msg
    )
class VLLMPool(AgentPool):
"""Agent pool that communicates via vLLM's OpenAI-compatible API.
Parameters
----------
endpoints:
Agent endpoints (host + port where vLLM is listening).
model:
Model identifier passed in the ``"model"`` field of every
request (e.g. ``"openai/gpt-oss-120b"``).
max_tokens:
        Maximum tokens to generate per request (used as the default cap).
Can be overridden via ``AURORA_SWARM_MAX_TOKENS`` env var.
max_tokens_aggregation:
Maximum tokens for aggregation/reduce steps (larger prompts).
Can be overridden via ``AURORA_SWARM_MAX_TOKENS_AGGREGATION`` env var.
Defaults to 2 * max_tokens if not specified.
model_max_context:
Model's maximum context length. If None, will be fetched from
vLLM's ``/v1/models`` endpoint on first request. Can be overridden
via ``AURORA_SWARM_MODEL_MAX_CONTEXT`` env var.
buffer:
Safety margin (in tokens) for dynamic sizing to account for
reasoning overhead. Defaults to 512.
    use_batch:
        If ``True``, use batch prompting via the completions API for
        :meth:`send_all_batched`. If ``False``, fall back to individual
        requests. Defaults to ``True``.
concurrency:
Maximum number of in-flight requests.
connector_limit:
Maximum TCP connections in the aiohttp pool.
timeout:
Base per-request timeout in seconds. Single requests use this;
batch requests use max(timeout, scaled) where scaled depends on batch size.
batch_concurrency:
vLLM's max concurrent sequences (waves). Used to scale batch timeout;
default 256.
    timeout_per_sequence:
        Estimated seconds per sequence for batch timeout scaling. Can be set
        via ``AURORA_SWARM_TIMEOUT_PER_SEQUENCE``. Default 60.0.
batch_timeout_cap:
If set, cap the computed batch timeout so one huge batch does not get
an extreme value. Optional.
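
    Examples
    --------
    A minimal usage sketch; the ``load_endpoints`` helper is hypothetical
    and stands in for however ``AgentEndpoint`` objects are built::

        endpoints = load_endpoints()  # hypothetical helper
        pool = VLLMPool(endpoints, model="openai/gpt-oss-120b")
        responses = await pool.send_all_batched(["prompt one", "prompt two"])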
"""
def __init__(
self,
endpoints: list[AgentEndpoint],
model: str = "openai/gpt-oss-120b",
max_tokens: int | None = None,
max_tokens_aggregation: int | None = None,
model_max_context: int | None = None,
buffer: int = 512,
use_batch: bool = True,
concurrency: int = 512,
connector_limit: int = 1024,
timeout: float = 300.0,
batch_concurrency: int = 256,
timeout_per_sequence: float | None = None,
batch_timeout_cap: float | None = None,
) -> None:
super().__init__(
endpoints,
concurrency=concurrency,
connector_limit=connector_limit,
timeout=timeout,
)
self._model = model
self._use_batch = use_batch
self._batch_concurrency = batch_concurrency
self._timeout_per_sequence = (
timeout_per_sequence
if timeout_per_sequence is not None
else float(os.environ.get("AURORA_SWARM_TIMEOUT_PER_SEQUENCE", "60.0"))
)
self._batch_timeout_cap = batch_timeout_cap
# Load from environment with fallbacks
        self._max_tokens = (
            max_tokens
            if max_tokens is not None
            else int(os.environ.get("AURORA_SWARM_MAX_TOKENS", "512"))
        )
        self._max_tokens_aggregation = (
            max_tokens_aggregation
            if max_tokens_aggregation is not None
            else int(
                os.environ.get(
                    "AURORA_SWARM_MAX_TOKENS_AGGREGATION",
                    str(self._max_tokens * 2),
                )
            )
        )
        self._model_max_context = model_max_context
        if self._model_max_context is None:
            env_ctx = os.environ.get("AURORA_SWARM_MODEL_MAX_CONTEXT")
            if env_ctx is not None:
                self._model_max_context = int(env_ctx)
self._buffer = buffer
self._model_max_context_cached: int | None = None
# Create OpenAI clients for each endpoint (for batch requests)
self._openai_clients: dict[int, AsyncOpenAI] = {}
for i, ep in enumerate(self._endpoints):
self._openai_clients[i] = AsyncOpenAI(
base_url=f"{ep.url}/v1",
api_key="EMPTY", # vLLM convention
timeout=timeout,
)
# -- model metadata -------------------------------------------------------
async def _get_model_max_context(self) -> int:
"""Fetch the model's max context length from vLLM /v1/models endpoint.
Cached after first call. Returns a sensible default if fetch fails.
"""
# Return cached value if available
if self._model_max_context_cached is not None:
return self._model_max_context_cached
# Return explicitly configured value
if self._model_max_context is not None:
self._model_max_context_cached = self._model_max_context
return self._model_max_context
# Fetch from vLLM API
try:
ep = self._endpoints[0]
session = await self._get_session()
async with session.get(
f"{ep.url}/v1/models",
timeout=aiohttp.ClientTimeout(total=10.0),
) as resp:
data = await resp.json()
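                # Expected payload shape (abridged), per the OpenAI models-list
                # format that vLLM implements:
                #   {"data": [{"id": "<model>", "max_model_len": 131072, ...}]}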
# Find our model in the list
for model_info in data.get("data", []):
if model_info.get("id") == self._model:
max_len = model_info.get("max_model_len")
if max_len:
self._model_max_context_cached = max_len
return max_len
except Exception:
pass # Fall back to default
# Default fallback (131072 is common for many models)
self._model_max_context_cached = 131072
return self._model_max_context_cached
# -- core request (OpenAI chat completions) ------------------------------
async def post(self, agent_index: int, prompt: str, max_tokens: int | None = None) -> Response:
"""Send *prompt* via the OpenAI chat-completions API on the agent.
The prompt is wrapped as a single ``user`` message.
Parameters
----------
agent_index:
Index of the agent to send the prompt to.
prompt:
The prompt text.
max_tokens:
Optional override for max tokens. If None, uses dynamic sizing
based on prompt length and model context limit.
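
        Examples
        --------
        Illustrative call; the pool, agent index, and prompt are
        placeholders::

            response = await pool.post(0, "Summarize this text.")
            if response.success:
                print(response.text)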
"""
ep = self._endpoints[agent_index]
session = await self._get_session()
# Compute max_tokens dynamically if not explicitly provided
if max_tokens is None:
# Get model's max context length
model_max = await self._get_model_max_context()
# Estimate prompt tokens (rough heuristic: 1 token ≈ 4 chars)
prompt_est = len(prompt) // 4
# Dynamic sizing: never exceed model capacity
# Use default max_tokens as the preferred cap
tokens = min(
self._max_tokens,
max(128, model_max - prompt_est - self._buffer)
)
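            # Worked example (assumed numbers): with model_max=131072, a
            # 4000-char prompt (~1000 tokens), and buffer=512, this yields
            # min(self._max_tokens, max(128, 131072 - 1000 - 512)), i.e. the
            # configured cap of 512 under the defaults.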
else:
tokens = max_tokens
async with self._semaphore:
try:
async with session.post(
f"{ep.url}/v1/chat/completions",
json={
"model": self._model,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": tokens,
},
timeout=aiohttp.ClientTimeout(total=self._timeout),
                ) as resp:
                    # Check the status before parsing: error bodies from a
                    # proxy or load balancer may not be JSON.
                    if resp.status != 200:
                        try:
                            data = await resp.json()
                            error_msg = data.get("error", {}).get(
                                "message", f"HTTP {resp.status}"
                            )
                        except Exception:
                            error_msg = f"HTTP {resp.status}"
                        return Response(
                            success=False,
                            text="",
                            error=f"API error: {error_msg}",
                            agent_index=agent_index,
                        )
                    data = await resp.json()
# Check for expected response structure
if "choices" not in data or not data["choices"]:
return Response(
success=False,
text="",
error=f"Invalid response structure: {list(data.keys())}",
agent_index=agent_index,
)
message = data["choices"][0]["message"]
text = message.get("content") or message.get("reasoning_content") or ""
return Response(
success=True,
text=text,
agent_index=agent_index,
)
except Exception as exc:
return Response(
success=False,
text="",
error=f"{type(exc).__name__}: {str(exc)}",
agent_index=agent_index,
)
async def post_batch(
self,
agent_index: int,
prompts: list[str],
max_tokens: int | None = None,
) -> list[Response]:
"""Send multiple prompts to one agent via the completions API.
Uses the OpenAI completions endpoint which supports batch prompts
(a list of strings). This reduces N HTTP requests to 1.
Parameters
----------
agent_index:
Index of the agent to send prompts to.
prompts:
List of prompts to send in one batch.
max_tokens:
Optional override for max tokens. If None, uses dynamic sizing
            based on average prompt length.

        Returns
-------
list[Response]
One Response per prompt, in the same order as the input.
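
        Examples
        --------
        Illustrative call; the agent index and prompts are placeholders::

            responses = await pool.post_batch(0, ["prompt a", "prompt b"])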
"""
if not prompts:
return []
        client = self._openai_clients[agent_index]
# Compute max_tokens dynamically if not explicitly provided
if max_tokens is None:
# Get model's max context length
model_max = await self._get_model_max_context()
# Estimate tokens based on average prompt length
avg_prompt_len = sum(len(p) for p in prompts) // len(prompts)
prompt_est = avg_prompt_len // 4
# Dynamic sizing: never exceed model capacity
tokens = min(
self._max_tokens,
max(128, model_max - prompt_est - self._buffer)
)
else:
tokens = max_tokens
# Batch-size-dependent timeout: scale with number of waves
n = len(prompts)
waves = max(1, math.ceil(n / self._batch_concurrency))
scaled = waves * self._timeout_per_sequence
effective_timeout = max(self._timeout, scaled)
if self._batch_timeout_cap is not None:
effective_timeout = min(effective_timeout, self._batch_timeout_cap)
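        # Worked example (defaults): n=1000 prompts with batch_concurrency=256
        # gives waves=4; at 60 s per wave the scaled timeout is 240 s, so
        # effective_timeout = max(300.0, 240.0) = 300.0 s.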
async with self._semaphore:
last_exc: BaseException | None = None
for attempt in range(6): # 6 attempts total (0..5)
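                # Exponential backoff between retries: 1 s, 2 s, 4 s, 8 s, 16 s.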
try:
# Call completions API with batch prompts
response = await client.completions.create(
model=self._model,
prompt=prompts,
max_tokens=tokens,
timeout=effective_timeout,
)
                    # Map choices back to prompts. Each choice carries an
                    # ``index`` field; sort defensively in case choices
                    # arrive out of order.
                    ordered = sorted(response.choices, key=lambda c: c.index)
                    return [
                        Response(
                            success=True,
                            text=choice.text,
                            agent_index=agent_index,
                        )
                        for choice in ordered
                    ]
except Exception as exc:
last_exc = exc
if attempt < 5 and _is_retryable_connection_error(exc):
await asyncio.sleep(2**attempt)
continue
# On error, return failed Response for each prompt
return [
Response(
success=False,
text="",
error=f"{type(exc).__name__}: {str(exc)}",
agent_index=agent_index,
)
for _ in prompts
]
# Should not reach here; if we do, treat last_exc as final failure
assert last_exc is not None
return [
Response(
success=False,
text="",
error=f"{type(last_exc).__name__}: {str(last_exc)}",
agent_index=agent_index,
)
for _ in prompts
]
async def send_all_batched(
self,
prompts: list[str],
max_tokens: int | None = None,
) -> list[Response]:
"""Send prompts using batch API, grouping by target agent.
Groups prompts by their target agent (round-robin based on index),
then sends one batched request per agent. Reconstructs results in
input order.
Parameters
----------
prompts:
List of prompts to send.
max_tokens:
            Optional max tokens override.

        Returns
-------
list[Response]
Responses in the same order as input prompts.
"""
if not self._use_batch or not prompts:
# Fall back to non-batched send_all
return await self.send_all(prompts)
# Group prompts by target agent
        groups: dict[int, list[tuple[int, str]]] = {}
        for i, prompt in enumerate(prompts):
            agent_idx = i % self.size
            groups.setdefault(agent_idx, []).append((i, prompt))
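        # Example: 5 prompts across 2 agents -> agent 0 receives prompts
        # 0, 2, 4 and agent 1 receives prompts 1, 3.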
# Send batched requests concurrently
async def send_group(agent_idx: int, items: list[tuple[int, str]]) -> list[tuple[int, Response]]:
"""Send batch to one agent, return (original_index, response) pairs."""
group_prompts = [prompt for _, prompt in items]
responses = await self.post_batch(agent_idx, group_prompts, max_tokens)
return [(items[j][0], responses[j]) for j in range(len(items))]
tasks = [send_group(agent_idx, items) for agent_idx, items in groups.items()]
all_results = await asyncio.gather(*tasks)
# Flatten and sort by original index
indexed_responses: list[tuple[int, Response]] = []
for result_group in all_results:
indexed_responses.extend(result_group)
indexed_responses.sort(key=lambda x: x[0])
# Extract responses in order
return [resp for _, resp in indexed_responses]
# -- sub-pool override ---------------------------------------------------
def _sub_pool(self, endpoints: list[AgentEndpoint]) -> "VLLMPool":
"""Create a child VLLMPool sharing concurrency settings."""
child = VLLMPool.__new__(VLLMPool)
child._endpoints = endpoints
child._concurrency = self._concurrency
child._connector_limit = self._connector_limit
child._timeout = self._timeout
child._batch_concurrency = self._batch_concurrency
child._timeout_per_sequence = self._timeout_per_sequence
child._batch_timeout_cap = self._batch_timeout_cap
child._semaphore = self._semaphore
child._session = self._session
child._model = self._model
child._use_batch = self._use_batch
child._max_tokens = self._max_tokens
child._max_tokens_aggregation = self._max_tokens_aggregation
child._model_max_context = self._model_max_context
child._buffer = self._buffer
child._model_max_context_cached = self._model_max_context_cached
child._openai_clients = self._openai_clients # Share clients
return child