# Source code for quackamollie.model_manager.langchain.langchain_model_manager
# -*- coding: utf-8 -*-
__all__ = ["LangchainQuackamollieModelManager"]
__author__ = "QuacktorAI"
__copyright__ = "Copyright 2024, Forge of Absurd Ducks"
__credits__ = ["QuacktorAI"]
from quackamollie.core.enum.model_family_icon import ModelFamilyIcon
from quackamollie.core.database.model import ChatMessage
from quackamollie.core.meta.model_manager.meta_quackamollie_model_manager import MetaQuackamollieModelManager
from typing import Dict, List, Optional, Tuple, Type
from quackamollie.model.meta.langchain.langchain_meta_model import MetaLangchainQuackamollieModel
from quackamollie.model_manager.langchain.helpers.langchain_model_entry_point import (
get_langchain_models_from_entrypoints
)
[docs]
class LangchainQuackamollieModelManager(MetaQuackamollieModelManager):
    """ Model manager managed by the `QuackamollieModelManagerRegistry` and serving models using Langchain """

    # Model family icons associated with this model manager
    families: List[ModelFamilyIcon] = [ModelFamilyIcon.LANGCHAIN]

    # Entry-point group name handed to `get_langchain_models_from_entrypoints`
    LANGCHAIN_ENTRYPOINT_GROUP: str = "quackamollie.model.langchain"

    # Class-level cache of the discovered model classes, `None` until loaded (and again after `reset`)
    _entrypoint_model_dict: Optional[Dict[str, Type[MetaLangchainQuackamollieModel]]] = None

    # NOTE(review): `_model_list` and `_model_families` are not declared in this class; presumably they are
    # inherited from `MetaQuackamollieModelManager` - confirm against the base class.

    @classmethod
    async def get_entrypoint_model_dict(cls) -> Optional[Dict[str, Type[MetaLangchainQuackamollieModel]]]:
        """ Get the Langchain model classes discovered from entry points, loading them on first use

            The value returned by `get_langchain_models_from_entrypoints` for `cls.LANGCHAIN_ENTRYPOINT_GROUP` is
            cached on the class. While the cached value is `None` the helper is called again on every invocation.

            :return: The cached dict of model classes indexed by name, or None if the helper returned None
            :rtype: Optional[Dict[str, Type[MetaLangchainQuackamollieModel]]]
        """
        if cls._entrypoint_model_dict is None:
            cls._entrypoint_model_dict = get_langchain_models_from_entrypoints(cls.LANGCHAIN_ENTRYPOINT_GROUP)
        return cls._entrypoint_model_dict

    @classmethod
    async def get_model_list(cls) -> Optional[List[str]]:
        """ Discover the models available for the model manager at runtime asynchronously

            The list is built once from the keys of the entry-point model dict and then cached on the class.

            :return: A list of available models for the model manager, or None if no model dict is available
            :rtype: Optional[List[str]]
        """
        if cls._model_list is None:
            entrypoint_model_dict = await cls.get_entrypoint_model_dict()
            if entrypoint_model_dict is not None:
                cls._model_list = list(entrypoint_model_dict)
        return cls._model_list

    @classmethod
    async def get_model_families(cls) -> Dict[str, List[ModelFamilyIcon]]:
        """ Discover the models families available for the model manager at runtime asynchronously

            The mapping is built once from the `model_families` attribute of every discovered model class and
            then cached on the class. If no model dict is available, the cached value is returned unchanged
            (it is None after `reset`).

            :return: A dict with values the list of families indexed by model name
            :rtype: Dict[str, List[ModelFamilyIcon]]
        """
        if cls._model_families is None:
            entrypoint_model_dict = await cls.get_entrypoint_model_dict()
            # Test the value that was just fetched (as `get_model_list` does) rather than the cache attribute,
            # so the check and the iteration below always refer to the same object.
            if entrypoint_model_dict is not None:
                cls._model_families = {
                    entrypoint_name: model_class.model_families
                    for entrypoint_name, model_class in entrypoint_model_dict.items()
                }
        return cls._model_families

    @classmethod
    def parse_chat_history(cls, chat_messages: Optional[List[ChatMessage]]) -> List[Tuple[str, str]]:
        """ Parse the chat history given as a list of `ChatMessage` from the database model to a list compatible
            with the model manager's models.

            Each message becomes a `(lowercased user type value, message content)` tuple. A `None` or empty
            input yields an empty list.

            :param chat_messages: A list of `ChatMessage` from the database model
            :type chat_messages: Optional[List[ChatMessage]]

            :return: A list of messages formatted to be compatible with the model manager's models.
            :rtype: List[Tuple[str, str]]
        """
        # Construct the list of messages in a format supported by Langchain
        return [
            (past_msg.user.user_type.value.lower(), past_msg.content)
            for past_msg in (chat_messages or [])
        ]

    @classmethod
    async def get_model_class(cls, model_name: str) -> Optional[Type[MetaLangchainQuackamollieModel]]:
        """ Get the model class from the model name

            :param model_name: Name of the model as listed by `cls.get_model_list`
            :type model_name: str

            :return: A subclass of MetaLangchainQuackamollieModel, or None if the name is unknown or no model
                     dict is available
            :rtype: Optional[Type[MetaLangchainQuackamollieModel]]
        """
        entrypoint_model_dict = await cls.get_entrypoint_model_dict()
        if entrypoint_model_dict is None:
            return None
        return entrypoint_model_dict.get(model_name)

    @classmethod
    def reset(cls) -> None:
        """ Reset the model manager dynamic fields to force reloading models. Be careful if used asynchronously """
        cls._entrypoint_model_dict = None
        cls._model_list = None
        cls._model_families = None