feat: add footer statistics tracking with LangChain callbacks

- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens
- Integrate callbacks into TradingAgentsGraph and all LLM clients
- Dynamic agent/report counts based on selected analysts
- Fix report completion counting (tied to agent completion)
This commit is contained in:
Yijia Xiao
2026-02-02 22:00:37 +00:00
parent b06936f420
commit 54cdb146d0
10 changed files with 277 additions and 112 deletions
@@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
"""Return configured ChatAnthropic instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+1 -5
View File
@@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
from .google_client import GoogleClient
from .vllm_client import VLLMClient
def create_llm_client(
@@ -16,7 +15,7 @@ def create_llm_client(
"""Create an LLM client for the specified provider.
Args:
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
model: Model name/identifier
base_url: Optional base URL for API endpoint
**kwargs: Additional provider-specific arguments
@@ -41,7 +40,4 @@ def create_llm_client(
if provider_lower == "google":
return GoogleClient(model, base_url, **kwargs)
if provider_lower == "vllm":
return VLLMClient(model, base_url, **kwargs)
raise ValueError(f"Unsupported LLM provider: {provider}")
+1 -1
View File
@@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
"""Return configured ChatGoogleGenerativeAI instance."""
llm_kwargs = {"model": self.model}
for key in ("timeout", "max_retries", "google_api_key"):
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+1 -1
View File
@@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
elif self.base_url:
llm_kwargs["base_url"] = self.base_url
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
if key in self.kwargs:
llm_kwargs[key] = self.kwargs[key]
+2 -2
View File
@@ -69,11 +69,11 @@ VALID_MODELS = {
def validate_model(provider: str, model: str) -> bool:
"""Check if model name is valid for the given provider.
For ollama, openrouter, vllm - any model is accepted.
For ollama, openrouter - any model is accepted.
"""
provider_lower = provider.lower()
if provider_lower in ("ollama", "openrouter", "vllm"):
if provider_lower in ("ollama", "openrouter"):
return True
if provider_lower not in VALID_MODELS:
-18
View File
@@ -1,18 +0,0 @@
from typing import Any, Optional
from .base_client import BaseLLMClient
class VLLMClient(BaseLLMClient):
"""Client for vLLM (placeholder for future implementation)."""
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
super().__init__(model, base_url, **kwargs)
def get_llm(self) -> Any:
"""Return configured vLLM instance."""
raise NotImplementedError("vLLM client not yet implemented")
def validate_model(self) -> bool:
"""Validate model for vLLM."""
return True