mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-29 01:01:33 +03:00
feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
# TradingAgents/graph/propagation.py
|
||||
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tradingagents.agents.utils.agent_states import (
|
||||
AgentState,
|
||||
InvestDebateState,
|
||||
@@ -41,9 +41,17 @@ class Propagator:
|
||||
"news_report": "",
|
||||
}
|
||||
|
||||
def get_graph_args(self) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation."""
|
||||
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
|
||||
"""Get arguments for the graph invocation.
|
||||
|
||||
Args:
|
||||
callbacks: Optional list of callback handlers for tool execution tracking.
|
||||
Note: LLM callbacks are handled separately via LLM constructor.
|
||||
"""
|
||||
config = {"recursion_limit": self.max_recur_limit}
|
||||
if callbacks:
|
||||
config["callbacks"] = callbacks
|
||||
return {
|
||||
"stream_mode": "values",
|
||||
"config": {"recursion_limit": self.max_recur_limit},
|
||||
"config": config,
|
||||
}
|
||||
|
||||
@@ -48,6 +48,7 @@ class TradingAgentsGraph:
|
||||
selected_analysts=["market", "social", "news", "fundamentals"],
|
||||
debug=False,
|
||||
config: Dict[str, Any] = None,
|
||||
callbacks: Optional[List] = None,
|
||||
):
|
||||
"""Initialize the trading agents graph and components.
|
||||
|
||||
@@ -55,9 +56,11 @@ class TradingAgentsGraph:
|
||||
selected_analysts: List of analyst types to include
|
||||
debug: Whether to run in debug mode
|
||||
config: Configuration dictionary. If None, uses default config
|
||||
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
|
||||
"""
|
||||
self.debug = debug
|
||||
self.config = config or DEFAULT_CONFIG
|
||||
self.callbacks = callbacks or []
|
||||
|
||||
# Update the interface's config
|
||||
set_config(self.config)
|
||||
@@ -71,6 +74,10 @@ class TradingAgentsGraph:
|
||||
# Initialize LLMs with provider-specific thinking configuration
|
||||
llm_kwargs = self._get_provider_kwargs()
|
||||
|
||||
# Add callbacks to kwargs if provided (passed to LLM constructor)
|
||||
if self.callbacks:
|
||||
llm_kwargs["callbacks"] = self.callbacks
|
||||
|
||||
deep_client = create_llm_client(
|
||||
provider=self.config["llm_provider"],
|
||||
model=self.config["deep_think_llm"],
|
||||
@@ -83,6 +90,7 @@ class TradingAgentsGraph:
|
||||
base_url=self.config.get("backend_url"),
|
||||
**llm_kwargs,
|
||||
)
|
||||
|
||||
self.deep_thinking_llm = deep_client.get_llm()
|
||||
self.quick_thinking_llm = quick_client.get_llm()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .vllm_client import VLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
@@ -16,7 +15,7 @@ def create_llm_client(
|
||||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
@@ -41,7 +40,4 @@ def create_llm_client(
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "vllm":
|
||||
return VLLMClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
@@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
|
||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key"):
|
||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
|
||||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -69,11 +69,11 @@ VALID_MODELS = {
|
||||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter, vllm - any model is accepted.
|
||||
For ollama, openrouter - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
||||
if provider_lower in ("ollama", "openrouter"):
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
class VLLMClient(BaseLLMClient):
|
||||
"""Client for vLLM (placeholder for future implementation)."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured vLLM instance."""
|
||||
raise NotImplementedError("vLLM client not yet implemented")
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for vLLM."""
|
||||
return True
|
||||
Reference in New Issue
Block a user