mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-29 11:01:40 +03:00
feat: add footer statistics tracking with LangChain callbacks
- Add StatsCallbackHandler for tracking LLM calls, tool calls, and tokens - Integrate callbacks into TradingAgentsGraph and all LLM clients - Dynamic agent/report counts based on selected analysts - Fix report completion counting (tied to agent completion)
This commit is contained in:
@@ -16,7 +16,7 @@ class AnthropicClient(BaseLLMClient):
|
||||
"""Return configured ChatAnthropic instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens"):
|
||||
for key in ("timeout", "max_retries", "api_key", "max_tokens", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from .base_client import BaseLLMClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
from .google_client import GoogleClient
|
||||
from .vllm_client import VLLMClient
|
||||
|
||||
|
||||
def create_llm_client(
|
||||
@@ -16,7 +15,7 @@ def create_llm_client(
|
||||
"""Create an LLM client for the specified provider.
|
||||
|
||||
Args:
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter, vllm)
|
||||
provider: LLM provider (openai, anthropic, google, xai, ollama, openrouter)
|
||||
model: Model name/identifier
|
||||
base_url: Optional base URL for API endpoint
|
||||
**kwargs: Additional provider-specific arguments
|
||||
@@ -41,7 +40,4 @@ def create_llm_client(
|
||||
if provider_lower == "google":
|
||||
return GoogleClient(model, base_url, **kwargs)
|
||||
|
||||
if provider_lower == "vllm":
|
||||
return VLLMClient(model, base_url, **kwargs)
|
||||
|
||||
raise ValueError(f"Unsupported LLM provider: {provider}")
|
||||
|
||||
@@ -38,7 +38,7 @@ class GoogleClient(BaseLLMClient):
|
||||
"""Return configured ChatGoogleGenerativeAI instance."""
|
||||
llm_kwargs = {"model": self.model}
|
||||
|
||||
for key in ("timeout", "max_retries", "google_api_key"):
|
||||
for key in ("timeout", "max_retries", "google_api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class OpenAIClient(BaseLLMClient):
|
||||
elif self.base_url:
|
||||
llm_kwargs["base_url"] = self.base_url
|
||||
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key"):
|
||||
for key in ("timeout", "max_retries", "reasoning_effort", "api_key", "callbacks"):
|
||||
if key in self.kwargs:
|
||||
llm_kwargs[key] = self.kwargs[key]
|
||||
|
||||
|
||||
@@ -69,11 +69,11 @@ VALID_MODELS = {
|
||||
def validate_model(provider: str, model: str) -> bool:
|
||||
"""Check if model name is valid for the given provider.
|
||||
|
||||
For ollama, openrouter, vllm - any model is accepted.
|
||||
For ollama, openrouter - any model is accepted.
|
||||
"""
|
||||
provider_lower = provider.lower()
|
||||
|
||||
if provider_lower in ("ollama", "openrouter", "vllm"):
|
||||
if provider_lower in ("ollama", "openrouter"):
|
||||
return True
|
||||
|
||||
if provider_lower not in VALID_MODELS:
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from .base_client import BaseLLMClient
|
||||
|
||||
|
||||
class VLLMClient(BaseLLMClient):
|
||||
"""Client for vLLM (placeholder for future implementation)."""
|
||||
|
||||
def __init__(self, model: str, base_url: Optional[str] = None, **kwargs):
|
||||
super().__init__(model, base_url, **kwargs)
|
||||
|
||||
def get_llm(self) -> Any:
|
||||
"""Return configured vLLM instance."""
|
||||
raise NotImplementedError("vLLM client not yet implemented")
|
||||
|
||||
def validate_model(self) -> bool:
|
||||
"""Validate model for vLLM."""
|
||||
return True
|
||||
Reference in New Issue
Block a user