feat: add footer statistics tracking with LangChain callbacks

- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens
- Integrate callbacks into TradingAgentsGraph and all LLM clients
- Dynamic agent/report counts based on selected analysts
- Fix report completion counting (tied to agent completion)
This commit is contained in:
Yijia Xiao
2026-02-02 22:00:37 +00:00
parent b06936f420
commit 54cdb146d0
10 changed files with 277 additions and 112 deletions
+12 -4
View File
@@ -1,6 +1,6 @@
# TradingAgents/graph/propagation.py
from typing import Dict, Any
from typing import Dict, Any, List, Optional
from tradingagents.agents.utils.agent_states import (
AgentState,
InvestDebateState,
@@ -41,9 +41,17 @@ class Propagator:
"news_report": "",
}
def get_graph_args(self) -> Dict[str, Any]:
"""Get arguments for the graph invocation."""
def get_graph_args(self, callbacks: Optional[List] = None) -> Dict[str, Any]:
"""Get arguments for the graph invocation.
Args:
callbacks: Optional list of callback handlers for tool execution tracking.
Note: LLM callbacks are handled separately via LLM constructor.
"""
config = {"recursion_limit": self.max_recur_limit}
if callbacks:
config["callbacks"] = callbacks
return {
"stream_mode": "values",
"config": {"recursion_limit": self.max_recur_limit},
"config": config,
}
+8
View File
@@ -48,6 +48,7 @@ class TradingAgentsGraph:
selected_analysts=["market", "social", "news", "fundamentals"],
debug=False,
config: Dict[str, Any] = None,
callbacks: Optional[List] = None,
):
"""Initialize the trading agents graph and components.
@@ -55,9 +56,11 @@ class TradingAgentsGraph:
selected_analysts: List of analyst types to include
debug: Whether to run in debug mode
config: Configuration dictionary. If None, uses default config
callbacks: Optional list of callback handlers (e.g., for tracking LLM/tool stats)
"""
self.debug = debug
self.config = config or DEFAULT_CONFIG
self.callbacks = callbacks or []
# Update the interface's config
set_config(self.config)
@@ -71,6 +74,10 @@ class TradingAgentsGraph:
# Initialize LLMs with provider-specific thinking configuration
llm_kwargs = self._get_provider_kwargs()
# Add callbacks to kwargs if provided (passed to LLM constructor)
if self.callbacks:
llm_kwargs["callbacks"] = self.callbacks
deep_client = create_llm_client(
provider=self.config["llm_provider"],
model=self.config["deep_think_llm"],
@@ -83,6 +90,7 @@ class TradingAgentsGraph:
base_url=self.config.get("backend_url"),
**llm_kwargs,
)
self.deep_thinking_llm = deep_client.get_llm()
self.quick_thinking_llm = quick_client.get_llm()
@@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
"""Return configured ChatAnthropic instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+1 -5
View File
@@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .vllm_client import VLLMClient
def create_llm_client(
@@ -16,7 +15,7 @@ def create_llm_client(
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
@@ -41,7 +40,4 @@ def create_llm_client(
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "vllm":
return VLLMClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")
+1 -1
View File
@@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
"""Return configured ChatGoogleGenerativeAI instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "google_api_key"):
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+1 -1
View File
@@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
elif self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+2 -2
View File
@@ -69,11 +69,11 @@ VALID_MODELS = {
def validate_model(provider: str, model: str) -> bool:
"""Check if model name is valid for the given provider.
For ollama, openrouter, vllm - any model is accepted.
For ollama, openrouter - any model is accepted.
"""
provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter", "vllm"):
if provider_lower in ("ollama", "openrouter"):
return True
if provider_lower not in VALID_MODELS:
-18
View File
@@ -1,18 +0,0 @@
from typing import Any, Optional
from .base_client import BaseLLMClient
class VLLMClient(BaseLLMClient):
"""Client for vLLM (placeholder for future implementation)."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured vLLM instance."""
raise NotImplementedError("vLLM client not yet implemented")
def validate_model(self) -> bool:
"""Validate model for vLLM."""
return True