Source code for acodex.thread

from __future__ import annotations

import logging
from collections.abc import AsyncIterator, Awaitable, Callable, Iterator
from typing import TYPE_CHECKING, Any, TypeVar, cast

from typing_extensions import Unpack

from acodex._internal.exec import build_exec_args
from acodex._internal.output_schema_file import create_output_schema_file
from acodex._internal.output_type import OutputTypeAdapter
from acodex._internal.thread_core import (
    build_turn_or_raise,
    initial_turn_state,
    parse_thread_event_jsonl,
    reduce_turn_state,
)
from acodex.exceptions import CodexThreadStreamNotConsumedError
from acodex.exec import AsyncCodexExec, CodexExec
from acodex.types.codex_options import CodexOptions
from acodex.types.events import (
    ThreadErrorEvent,
    ThreadEvent,
    ThreadStartedEvent,
    TurnFailedEvent,
)
from acodex.types.input import Input
from acodex.types.thread_options import ThreadOptions
from acodex.types.turn import AsyncRunStreamedResult, RunResult, RunStreamedResult
from acodex.types.turn_options import TurnOptions

if TYPE_CHECKING:
    T = TypeVar("T", default=Any)
else:
    T = TypeVar("T")


logger = logging.getLogger(__name__)


class Thread:
    """Represent a thread of conversation with the agent.

    One thread can have multiple consecutive turns.
    """

    def __init__(
        self,
        *,
        exec: CodexExec,  # noqa: A002
        options: CodexOptions,
        thread_options: ThreadOptions,
        thread_id: str | None = None,
    ) -> None:
        """Store the executor and option sets used by every turn of this thread.

        Args:
            exec: Executor whose ``run(args)`` yields the raw output lines of a turn.
            options: Options forwarded to ``build_exec_args`` for every turn.
            thread_options: Thread options forwarded to ``build_exec_args``.
            thread_id: Existing thread ID, or ``None`` when no ID is known yet.
        """
        self._exec = exec
        self._options = options
        self._id = thread_id
        self._thread_options = thread_options
        logger.debug("Created Thread instance (thread_id=%s)", self._id)

    @property
    def id(self) -> str | None:
        """Return the ID of the thread.

        The ID is populated after the first turn starts.
        """
        return self._id

    def run_streamed(
        self,
        input: Input,  # noqa: A002
        output_type: type[T] | None = None,
        **turn_options: Unpack[TurnOptions],
    ) -> RunStreamedResult[T]:
        """Provide input to the agent and stream turn events as they are produced.

        Set `turn_options["signal"]` via `event.set()` to request cancellation.
        Use `threading.Event` in synchronous flows.

        Nothing is executed until the returned ``events`` iterator is advanced.
        The result factory raises ``CodexThreadStreamNotConsumedError`` until
        that iterator has been consumed to the end.

        Args:
            input: Input forwarded to ``build_exec_args`` for this turn.
            output_type: Optional type handed to ``OutputTypeAdapter`` together
                with ``turn_options["output_schema"]``.
            **turn_options: Per-turn options forwarded to ``build_exec_args``.

        Returns:
            A streamed turn result with an iterator of parsed events.
        """
        logger.info(
            "Starting streamed turn request (thread_id=%s, output_type=%s, output_schema=%s)",
            self._id,
            output_type is not None,
            turn_options.get("output_schema") is not None,
        )
        state = initial_turn_state()
        stream_completed = False
        output_type_adapter = OutputTypeAdapter(
            output_type=output_type,
            output_schema=turn_options.get("output_schema"),
        )

        def build_result() -> RunResult[T]:
            # The reduced state is only complete once every event has been seen.
            if not stream_completed:
                raise CodexThreadStreamNotConsumedError(
                    "streamed.result is unavailable until streamed.events is fully consumed",
                )
            return build_turn_or_raise(state, output_type_adapter=output_type_adapter)

        def event_generator() -> Iterator[ThreadEvent]:
            nonlocal state, stream_completed
            schema_file = create_output_schema_file(schema=output_type_adapter.json_schema())
            line_stream: Iterator[str] | None = None
            try:
                exec_args = build_exec_args(
                    input=input,
                    options=self._options,
                    thread_options=self._thread_options,
                    thread_id=self._id,
                    turn_options=turn_options,
                    output_schema_path=schema_file.schema_path,
                )
                line_stream = self._exec.run(exec_args)
                for line in line_stream:
                    event = parse_thread_event_jsonl(line)
                    if event is None:
                        # The parser produced no event for this line; skip it.
                        continue
                    if isinstance(event, ThreadStartedEvent):
                        self._id = event.thread_id
                        logger.info("Assigned thread ID from event stream: %s", self._id)
                    # Failures are surfaced at warning level, everything else at debug.
                    log_method = (
                        logger.warning
                        if isinstance(event, (TurnFailedEvent, ThreadErrorEvent))
                        else logger.debug
                    )
                    log_method(
                        "Received event class=%s type=%s thread_id=%s event=%r",
                        event.__class__.__name__,
                        event.type,
                        self._id,
                        event,
                    )
                    state = reduce_turn_state(state, event)
                    yield event
                stream_completed = True
                logger.info("Completed streamed turn request (thread_id=%s)", self._id)
            finally:
                # Nested try/finally: the schema file must be cleaned up even
                # when closing the line stream raises.
                try:
                    if line_stream is not None:
                        _close_if_possible(line_stream)
                finally:
                    schema_file.cleanup()
                    logger.debug(
                        "Cleaned up streamed turn resources (thread_id=%s)",
                        self._id,
                    )

        return RunStreamedResult(events=event_generator(), result_factory=build_result)

    def run(
        self,
        input: Input,  # noqa: A002
        output_type: type[T] | None = None,
        **turn_options: Unpack[TurnOptions],
    ) -> RunResult[T]:
        """Provide input to the agent and return the completed turn.

        Set `turn_options["signal"]` via `event.set()` to request cancellation.
        Use `threading.Event` in synchronous flows.

        This drains the event stream produced by ``run_streamed`` and then
        returns its result.

        Args:
            input: Input forwarded to ``run_streamed``.
            output_type: Optional output type forwarded to ``run_streamed``.
            **turn_options: Per-turn options forwarded to ``run_streamed``.

        Returns:
            The completed turn with reduced items, final response, and usage.
        """
        logger.info("Running turn request to completion (thread_id=%s)", self._id)
        streamed = self.run_streamed(input, output_type=output_type, **turn_options)
        events = streamed.events
        try:
            for _event in events:
                pass
        finally:
            # Close explicitly so the stream's cleanup runs even if iteration raised.
            _close_if_possible(events)
        logger.info("Completed turn request (thread_id=%s)", self._id)
        return streamed.result
class AsyncThread:
    """Represent a thread of conversation with the agent.

    One thread can have multiple consecutive turns.
    """

    def __init__(
        self,
        *,
        exec: AsyncCodexExec,  # noqa: A002
        options: CodexOptions,
        thread_options: ThreadOptions,
        thread_id: str | None = None,
    ) -> None:
        """Store the executor and option sets used by every turn of this thread.

        Args:
            exec: Executor whose ``run(args)`` asynchronously yields the raw
                output lines of a turn.
            options: Options forwarded to ``build_exec_args`` for every turn.
            thread_options: Thread options forwarded to ``build_exec_args``.
            thread_id: Existing thread ID, or ``None`` when no ID is known yet.
        """
        self._exec = exec
        self._options = options
        self._id = thread_id
        self._thread_options = thread_options
        logger.debug("Created AsyncThread instance (thread_id=%s)", self._id)

    @property
    def id(self) -> str | None:
        """Return the ID of the thread.

        The ID is populated after the first turn starts.
        """
        return self._id

    async def run_streamed(
        self,
        input: Input,  # noqa: A002
        output_type: type[T] | None = None,
        **turn_options: Unpack[TurnOptions],
    ) -> AsyncRunStreamedResult[T]:
        """Provide input to the agent and stream turn events as they are produced.

        Set `turn_options["signal"]` via `event.set()` to request cancellation.
        Use `asyncio.Event` in asynchronous flows.

        Nothing is executed until the returned ``events`` iterator is advanced.
        The result factory raises ``CodexThreadStreamNotConsumedError`` until
        that iterator has been consumed to the end.

        Args:
            input: Input forwarded to ``build_exec_args`` for this turn.
            output_type: Optional type handed to ``OutputTypeAdapter`` together
                with ``turn_options["output_schema"]``.
            **turn_options: Per-turn options forwarded to ``build_exec_args``.

        Returns:
            A streamed turn result with an async iterator of parsed events.
        """
        logger.info(
            "Starting async streamed turn request (thread_id=%s, output_type=%s, output_schema=%s)",
            self._id,
            output_type is not None,
            turn_options.get("output_schema") is not None,
        )
        state = initial_turn_state()
        stream_completed = False
        output_type_adapter = OutputTypeAdapter(
            output_type=output_type,
            output_schema=turn_options.get("output_schema"),
        )

        def build_result() -> RunResult[T]:
            # The reduced state is only complete once every event has been seen.
            if not stream_completed:
                raise CodexThreadStreamNotConsumedError(
                    "streamed.result is unavailable until streamed.events is fully consumed",
                )
            return build_turn_or_raise(state, output_type_adapter=output_type_adapter)

        async def event_generator() -> AsyncIterator[ThreadEvent]:
            nonlocal state, stream_completed
            schema_file = create_output_schema_file(schema=output_type_adapter.json_schema())
            line_stream: AsyncIterator[str] | None = None
            try:
                exec_args = build_exec_args(
                    input=input,
                    options=self._options,
                    thread_options=self._thread_options,
                    thread_id=self._id,
                    turn_options=turn_options,
                    output_schema_path=schema_file.schema_path,
                )
                line_stream = self._exec.run(exec_args)
                async for line in line_stream:
                    event = parse_thread_event_jsonl(line)
                    if event is None:
                        # The parser produced no event for this line; skip it.
                        continue
                    if isinstance(event, ThreadStartedEvent):
                        self._id = event.thread_id
                        logger.info("Assigned async thread ID from event stream: %s", self._id)
                    # Failures are surfaced at warning level, everything else at debug.
                    log_method = (
                        logger.warning
                        if isinstance(event, (TurnFailedEvent, ThreadErrorEvent))
                        else logger.debug
                    )
                    log_method(
                        "Received event class=%s type=%s thread_id=%s event=%r",
                        event.__class__.__name__,
                        event.type,
                        self._id,
                        event,
                    )
                    state = reduce_turn_state(state, event)
                    yield event
                stream_completed = True
                logger.info("Completed async streamed turn request (thread_id=%s)", self._id)
            finally:
                # Nested try/finally: awaiting the stream's close is a point
                # where an exception (including cancellation) can surface, and
                # the schema file must still be cleaned up in that case.
                try:
                    if line_stream is not None:
                        await _aclose_if_possible(line_stream)
                finally:
                    schema_file.cleanup()
                    logger.debug(
                        "Cleaned up async streamed turn resources (thread_id=%s)",
                        self._id,
                    )

        return AsyncRunStreamedResult(events=event_generator(), result_factory=build_result)

    async def run(
        self,
        input: Input,  # noqa: A002
        output_type: type[T] | None = None,
        **turn_options: Unpack[TurnOptions],
    ) -> RunResult[T]:
        """Provide input to the agent and return the completed turn.

        Set `turn_options["signal"]` via `event.set()` to request cancellation.
        Use `asyncio.Event` in asynchronous flows.

        This drains the event stream produced by ``run_streamed`` and then
        returns its result.

        Args:
            input: Input forwarded to ``run_streamed``.
            output_type: Optional output type forwarded to ``run_streamed``.
            **turn_options: Per-turn options forwarded to ``run_streamed``.

        Returns:
            The completed turn with reduced items, final response, and usage.
        """
        logger.info("Running async turn request to completion (thread_id=%s)", self._id)
        streamed = await self.run_streamed(input, output_type=output_type, **turn_options)
        events = streamed.events
        try:
            async for _event in events:
                pass
        finally:
            # Close explicitly so the stream's cleanup runs even if iteration raised.
            await _aclose_if_possible(events)
        logger.info("Completed async turn request (thread_id=%s)", self._id)
        return streamed.result
def _close_if_possible(iterator: object) -> None:
    """Invoke ``iterator.close()`` when the object has a ``close`` attribute.

    Objects without a ``close`` attribute (or with it set to ``None``) are
    left untouched.
    """
    closer = getattr(iterator, "close", None)
    if closer is not None:
        cast("Callable[[], None]", closer)()


async def _aclose_if_possible(iterator: object) -> None:
    """Await ``iterator.aclose()`` when the object has an ``aclose`` attribute.

    Objects without an ``aclose`` attribute (or with it set to ``None``) are
    left untouched.
    """
    async_closer = getattr(iterator, "aclose", None)
    if async_closer is not None:
        await cast("Callable[[], Awaitable[None]]", async_closer)()