KAG/kag/common/llm/client/llm_client.py

# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
import os
import json
from pathlib import Path
from typing import Union, Dict, List, Any
import logging
import traceback
import yaml
from kag.common.base.prompt_op import PromptOp
from kag.common.llm.config import *

logger = logging.getLogger(__name__)

config_cls_map = {
    "maas": OpenAIConfig,
    "vllm": VLLMConfig,
    "ollama": OllamaConfig,
}


def get_config_cls(config: dict):
    client_type = config.get("client_type", None)
    return config_cls_map.get(client_type, None)
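
# For example, get_config_cls({"client_type": "vllm"}) resolves to VLLMConfig;
# an unrecognized client_type yields None and is rejected in from_config below.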


def get_llm_cls(config: LLMConfig):
    from kag.common.llm.client import VLLMClient, OpenAIClient, OllamaClient

    return {
        VLLMConfig: VLLMClient,
        OpenAIConfig: OpenAIClient,
        OllamaConfig: OllamaClient,
    }[config.__class__]


class LLMClient:
    # Name of the underlying model, forwarded by the concrete clients
    model: str

    def __init__(self, **kwargs):
        self.model = kwargs.get("model", None)

    @classmethod
    def from_config(cls, config: Union[str, dict]):
        """
        Initialize an LLMClient instance from a configuration file or dictionary.

        :param config: Path to a configuration file or a configuration dictionary
        :return: Initialized LLMClient instance
        :raises FileNotFoundError: If the configuration file is not found
        :raises ValueError: If the model type is unsupported
        """
        if isinstance(config, str):
            config_path = Path(config)
            if config_path.is_file():
                try:
                    with open(config_path, "r") as f:
                        nn_config = yaml.safe_load(f)
                except yaml.YAMLError:
                    logger.error(f"Failed to parse config file: {config}")
                    raise
            else:
                logger.error(f"Config file not found: {config}")
                raise FileNotFoundError(f"Config file not found: {config}")
        else:
            # If config is already a dictionary, use it directly
            nn_config = config
        config_cls = get_config_cls(nn_config)
        if config_cls is None:
            client_type = nn_config.get("client_type", None)
            logger.error(f"Unsupported model type: {client_type}")
            raise ValueError(f"Unsupported model type: {client_type}")
        llm_config = config_cls(**nn_config)
        llm_cls = get_llm_cls(llm_config)
        return llm_cls(llm_config)
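
    # A minimal usage sketch. Only "client_type" is consumed by this module;
    # the remaining keys depend on the matching LLMConfig subclass, and the
    # ones shown here ("model", "api_key", "base_url") are illustrative
    # assumptions, not a confirmed schema:
    #
    #   client = LLMClient.from_config({
    #       "client_type": "maas",
    #       "model": "gpt-4o",
    #       "api_key": "sk-xxx",
    #       "base_url": "https://api.openai.com/v1",
    #   })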

    def __call__(self, prompt: Union[str, dict, list]) -> str:
        """
        Perform inference on the given prompt and return the result.

        :param prompt: Input prompt for inference
        :return: Inference result
        :raises NotImplementedError: If the subclass has not implemented this method
        """
        raise NotImplementedError
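
    # Concrete clients (VLLMClient, OpenAIClient, OllamaClient) implement
    # __call__; a minimal hypothetical subclass for illustration:
    #
    #   class EchoClient(LLMClient):
    #       def __call__(self, prompt):
    #           return f"echo: {prompt}"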

    def call_with_json_parse(self, prompt: Union[str, dict, list]):
        """
        Perform inference on the given prompt and attempt to parse the result as JSON.

        :param prompt: Input prompt for inference
        :return: Parsed JSON object, or the raw response string if parsing fails
        :raises NotImplementedError: If the subclass has not implemented __call__
        """
        res = self(prompt)
        # Extract the payload of a fenced ```json ... ``` block if present;
        # otherwise try to parse the whole response.
        _end = res.rfind("```")
        _start = res.find("```json")
        if _end != -1 and _start != -1:
            json_str = res[_start + len("```json"): _end].strip()
        else:
            json_str = res
        try:
            json_result = json.loads(json_str)
        except json.JSONDecodeError:
            return res
        return json_result
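
    # For example, a completion like 'Sure:\n```json\n{"answer": 42}\n```'
    # yields {"answer": 42}, while an unparseable completion is returned
    # verbatim so the caller can still inspect the raw text.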

    def invoke(self, variables: Dict[str, Any], prompt_op: PromptOp, with_json_parse: bool = True):
        """
        Call the model and process the result.

        :param variables: Variables used to build the prompt
        :param prompt_op: Prompt operation object for building and parsing prompts
        :param with_json_parse: Whether to attempt parsing the response as JSON
        :return: Processed result list
        """
        result = []
        prompt = prompt_op.build_prompt(variables)
        logger.debug(f"Prompt: {prompt}")
        if not prompt:
            return result
        try:
            response = self.call_with_json_parse(prompt=prompt) if with_json_parse else self(prompt)
            logger.debug(f"Response: {response}")
            result = prompt_op.parse_response(response, model=self.model, **variables)
            logger.debug(f"Result: {result}")
        except Exception as e:
            logger.error(f"Error {e} during invocation")
            logger.debug(traceback.format_exc())
        return result
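
    # A hedged usage sketch (the concrete PromptOp subclass and its inputs are
    # assumptions; invoke only relies on build_prompt/parse_response):
    #
    #   llm = LLMClient.from_config(conf_dict)
    #   results = llm.invoke({"input": "some text"}, ner_prompt_op)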

    def batch(self, variables: Dict[str, Any], prompt_op: PromptOp, with_json_parse: bool = True) -> List:
        """
        Batch process prompts.

        :param variables: Variables used to build the prompts
        :param prompt_op: Prompt operation object for building and parsing prompts
        :param with_json_parse: Whether to attempt parsing the response as JSON
        :return: List of all processed results
        """
        results = []
        prompts = prompt_op.build_prompt(variables)
        # If there is only one prompt, call the `invoke` method directly
        if isinstance(prompts, str):
            return self.invoke(variables, prompt_op, with_json_parse=with_json_parse)
        for idx, prompt in enumerate(prompts):
            logger.debug(f"Prompt_{idx}: {prompt}")
            try:
                response = self.call_with_json_parse(prompt=prompt) if with_json_parse else self(prompt)
                logger.debug(f"Response_{idx}: {response}")
                result = prompt_op.parse_response(response, idx=idx, model=self.model, **variables)
                logger.debug(f"Result_{idx}: {result}")
                results.extend(result)
            except Exception as e:
                logger.error(f"Error processing prompt {idx}: {e}")
                logger.debug(traceback.format_exc())
                continue
        return results
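
    # Note: a prompt whose call or parsing raises is logged and skipped, so
    # batch() returns results only for the prompts that succeeded.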


if __name__ == "__main__":
    import ast

    from kag.common.env import init_kag_config

    configFilePath = "/ossfs/workspace/workspace/openspgapp/openspg/python/kag/kag/common/default_config.cfg"
    init_kag_config(configFilePath)
    # KAG_LLM is expected to hold a dict literal; ast.literal_eval parses it
    # without the arbitrary-code-execution risk of eval()
    model = ast.literal_eval(os.getenv("KAG_LLM"))
    print(model)
    llm = LLMClient.from_config(model)
    res = llm("who are you?")
    print(res)