mirror of
https://github.com/farcasclaudiu/TradingAgents.git
synced 2026-06-29 13:01:38 +03:00
Merge pull request #464 from CadeYu/sync-validator-models
sync model validation with cli catalog
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
|
||||
def normalize_content(response):
|
||||
@@ -29,6 +30,27 @@ class BaseLLMClient(ABC):
|
||||
self.base_url = base_url
|
||||
self.kwargs = kwargs
|
||||
|
||||
def get_provider_name(self) -> str:
|
||||
"""Return the provider name used in warning messages."""
|
||||
provider = getattr(self, "provider", None)
|
||||
if provider:
|
||||
return str(provider)
|
||||
return self.__class__.__name__.removesuffix("Client").lower()
|
||||
|
||||
def warn_if_unknown_model(self) -> None:
|
||||
"""Warn when the model is outside the known list for the provider."""
|
||||
if self.validate_model():
|
||||
return
|
||||
|
||||
warnings.warn(
|
||||
(
|
||||
f"Model '{self.model}' is not in the known model list for "
|
||||
f"provider '{self.get_provider_name()}'. Continuing anyway."
|
||||
),
|
||||
RuntimeWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_llm(self) -> Any:
|
||||
"""Return the configured LLM instance."""
|
||||
|
||||
Reference in New Issue
Block a user