# Source code for aurora_swarm.embedding_pool

"""EmbeddingPool — async pool for OpenAI-compatible /v1/embeddings endpoints.

Provides scatter-gather over embedding servers with the same hostfile/selector
API as AgentPool, so parse_hostfile + by_tag work the same way.
"""

from __future__ import annotations

import asyncio
import random
from dataclasses import dataclass
from typing import TYPE_CHECKING, Sequence

from aurora_swarm.hostfile import AgentEndpoint

if TYPE_CHECKING:
    from openai import AsyncOpenAI


# ---------------------------------------------------------------------------
# EmbeddingResponse
# ---------------------------------------------------------------------------

@dataclass
class EmbeddingResponse:
    """Outcome of one embedding request against a single endpoint.

    Instances are plain value objects: a request that fails is reported
    through ``success=False`` plus an ``error`` message rather than by
    raising.
    """

    # True when an embedding vector was obtained.
    success: bool
    # The embedding vector, or None when the request did not succeed.
    embedding: list[float] | None
    # Human-readable failure description; None when nothing went wrong.
    error: str | None = None
    # Index of the endpoint the request was addressed to; -1 when no
    # endpoint applies.
    agent_index: int = -1
# ---------------------------------------------------------------------------
# EmbeddingPool
# ---------------------------------------------------------------------------

def _resolve_endpoints(
    endpoints: Sequence[AgentEndpoint | tuple[str, int]],
) -> list[AgentEndpoint]:
    """Return *endpoints* as a new list containing only AgentEndpoint objects.

    Items that already are :class:`AgentEndpoint` instances are passed
    through unchanged; every other item is unpacked as a ``(host, port)``
    pair and wrapped in ``AgentEndpoint(host=..., port=...)``.  Input order
    is preserved.
    """

    def _as_endpoint(item: AgentEndpoint | tuple[str, int]) -> AgentEndpoint:
        # Already the right type: reuse the very same object.
        if isinstance(item, AgentEndpoint):
            return item
        # Anything else must unpack into exactly two values.
        host, port = item
        return AgentEndpoint(host=host, port=port)

    return [_as_endpoint(item) for item in endpoints]
class EmbeddingPool:
    """Async pool of embedding endpoints (OpenAI-compatible /v1/embeddings).

    Parameters
    ----------
    endpoints:
        Embedding endpoints — either :class:`AgentEndpoint` objects or
        ``(host, port)`` tuples (tags will be empty).
    model:
        Embedding model id (e.g. sentence-transformers/all-MiniLM-L6-v2).
    concurrency:
        Maximum number of in-flight requests (asyncio semaphore size).
    timeout:
        Per-request timeout in seconds.

    Notes
    -----
    Sub-pools returned by :meth:`by_tag`, :meth:`sample`, :meth:`select` and
    :meth:`slice` share this pool's semaphore (so the concurrency limit is
    global across them) but lazily build their own client objects.  Closing
    one pool therefore does not close the clients of another.
    """

    def __init__(
        self,
        endpoints: Sequence[AgentEndpoint | tuple[str, int]],
        model: str,
        concurrency: int = 512,
        timeout: float = 60.0,
    ) -> None:
        self._endpoints = _resolve_endpoints(endpoints)
        self._model = model
        self._concurrency = concurrency
        self._timeout = timeout
        self._semaphore = asyncio.Semaphore(concurrency)
        # Clients are created on first use, see _get_clients().
        self._clients: list[AsyncOpenAI] | None = None

    async def _get_clients(self) -> list[AsyncOpenAI]:
        """Create and cache one AsyncOpenAI per endpoint."""
        if self._clients is not None:
            return self._clients
        # Imported lazily so that merely importing this module does not
        # require the openai package.
        from openai import AsyncOpenAI

        self._clients = [
            AsyncOpenAI(
                base_url=f"{ep.url}/v1",
                api_key="EMPTY",
                timeout=self._timeout,
            )
            for ep in self._endpoints
        ]
        return self._clients

    @property
    def size(self) -> int:
        """Number of endpoints in the pool."""
        return len(self._endpoints)

    @property
    def timeout(self) -> float:
        """Per-request timeout in seconds."""
        return self._timeout

    @property
    def endpoints(self) -> list[AgentEndpoint]:
        """A shallow copy of the pool's endpoint list."""
        return list(self._endpoints)

    def by_tag(self, key: str, value: str) -> EmbeddingPool:
        """Return a sub-pool of endpoints whose tag *key* equals *value*."""
        filtered = [ep for ep in self._endpoints if ep.tags.get(key) == value]
        return self._sub_pool(filtered)

    def sample(self, n: int) -> EmbeddingPool:
        """Return a sub-pool of *n* randomly chosen endpoints.

        *n* is capped at the pool size, so asking for more endpoints than
        exist returns all of them (in random order).
        """
        chosen = random.sample(self._endpoints, min(n, len(self._endpoints)))
        return self._sub_pool(chosen)

    def select(self, indices: Sequence[int]) -> EmbeddingPool:
        """Return a sub-pool with endpoints at the given indices.

        Raises
        ------
        IndexError
            If any index is outside the pool.
        """
        selected = [self._endpoints[i] for i in indices]
        return self._sub_pool(selected)

    def slice(self, start: int, stop: int) -> EmbeddingPool:
        """Return a sub-pool from index *start* to *stop* (list-slice semantics)."""
        return self._sub_pool(self._endpoints[start:stop])

    def _sub_pool(self, endpoints: list[AgentEndpoint]) -> EmbeddingPool:
        """Create a child pool with filtered endpoints (own clients when used)."""
        # Bypass __init__ so the child reuses the parent's semaphore instead
        # of allocating a fresh one.
        child = EmbeddingPool.__new__(EmbeddingPool)
        child._endpoints = endpoints
        child._model = self._model
        child._concurrency = self._concurrency
        child._timeout = self._timeout
        child._semaphore = self._semaphore
        child._clients = None
        return child

    async def embed_one(self, agent_index: int, text: str) -> EmbeddingResponse:
        """Request embedding for *text* from the endpoint at *agent_index*.

        Never raises for request failures: an out-of-range index, an empty
        server response, or any exception from the client is reported as an
        :class:`EmbeddingResponse` with ``success=False``.
        """
        # Also covers the empty pool, where no index is valid.
        if not 0 <= agent_index < len(self._endpoints):
            return EmbeddingResponse(
                success=False,
                embedding=None,
                error="Invalid agent_index",
                agent_index=agent_index,
            )
        clients = await self._get_clients()
        async with self._semaphore:
            try:
                r = await clients[agent_index].embeddings.create(
                    model=self._model,
                    input=[text],
                )
                if r.data:
                    return EmbeddingResponse(
                        success=True,
                        embedding=list(r.data[0].embedding),
                        agent_index=agent_index,
                    )
                return EmbeddingResponse(
                    success=False,
                    embedding=None,
                    error="Empty response data",
                    agent_index=agent_index,
                )
            except Exception as exc:
                # Deliberately broad: failures are reported per request so a
                # single bad endpoint cannot abort the gather in embed_all().
                return EmbeddingResponse(
                    success=False,
                    embedding=None,
                    error=str(exc),
                    agent_index=agent_index,
                )

    async def embed_all(self, texts: list[str]) -> list[EmbeddingResponse]:
        """Scatter *texts* across endpoints round-robin; return responses in input order."""
        n = self.size
        if n == 0:
            return [
                EmbeddingResponse(
                    success=False,
                    embedding=None,
                    error="Empty pool",
                    agent_index=-1,
                )
                for _ in texts
            ]
        tasks = [self.embed_one(i % n, text) for i, text in enumerate(texts)]
        results = await asyncio.gather(*tasks)
        return list(results)

    async def close(self) -> None:
        """Close every cached client and drop the client cache.

        Safe to call repeatedly and on a pool that never issued a request.
        Clients are rebuilt lazily if the pool is used again afterwards.
        Errors raised while closing individual clients are suppressed so
        that ``close()`` (and therefore ``__aexit__``) never raises.
        """
        # Detach the cache first so the pool is in a clean state even if
        # closing is interrupted.
        clients, self._clients = self._clients, None
        if not clients:
            return
        # Best-effort teardown: one client failing to close must not keep
        # the remaining clients' connections open.
        await asyncio.gather(
            *(client.close() for client in clients),
            return_exceptions=True,
        )

    async def __aenter__(self) -> EmbeddingPool:
        return self

    async def __aexit__(self, *exc: object) -> None:
        await self.close()