KAG/knext/command/sub_command/project.py

272 lines
8.6 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2023 OpenSPG Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.
from collections import OrderedDict
import re
import json
import os
import sys
from configparser import ConfigParser
from pathlib import Path
from ruamel.yaml import YAML
from typing import Optional
import click
from knext.common.utils import copytree, copyfile
from knext.project.client import ProjectClient
from knext.common.env import env, DEFAULT_HOST_ADDR
from kag.common.llm.llm_config_checker import LLMConfigChecker
from kag.common.vectorize_model.vectorize_model_config_checker import (
VectorizeModelConfigChecker,
)
from shutil import copy2
yaml = YAML()
yaml.default_flow_style = False
yaml.indent(mapping=2, sequence=4, offset=2)
def _render_template(namespace: str, tmpl: str, **kwargs):
config_path = kwargs.get("config_path", None)
project_dir = Path(namespace)
if not project_dir.exists():
project_dir.mkdir()
import kag.templates.project
src = Path(kag.templates.project.__path__[0])
copytree(
src,
project_dir.resolve(),
namespace=namespace,
root=namespace,
tmpl=tmpl,
**kwargs,
)
import kag.templates.schema
src = Path(kag.templates.schema.__path__[0]) / f"{{{{{tmpl}}}}}.schema.tmpl"
if not src.exists():
click.secho(
f"ERROR: No such schema template: {tmpl}.schema.tmpl",
fg="bright_red",
)
dst = project_dir.resolve() / "schema" / f"{{{{{tmpl}}}}}.schema.tmpl"
copyfile(src, dst, namespace=namespace, **{tmpl: namespace})
tmpls = [tmpl, "default"] if tmpl != "default" else [tmpl]
# find all .yaml files in project dir
config = yaml.load(Path(config_path).read_text() or "{}")
project_id = kwargs.get("id", None)
config["project"]["id"] = project_id
config_file_path = project_dir.resolve() / "kag_config.yaml"
with open(config_file_path, "w") as config_file:
yaml.dump(config, config_file)
return project_dir
def _recover_project(prj_path: str):
"""
Recover project by a project dir path.
"""
if not Path(prj_path).exists():
click.secho(f"ERROR: No such directory: {prj_path}", fg="bright_red")
sys.exit()
project_name = env.project_config.get("namespace", None)
namespace = env.project_config.get("namespace", None)
desc = env.project_config.get("description", None)
if not namespace:
click.secho(
f"ERROR: No project namespace found in {env.config_path}.",
fg="bright_red",
)
sys.exit()
client = ProjectClient()
project = client.get(namespace=namespace) or client.create(
name=project_name, desc=desc, namespace=namespace, config=json.dumps(env.config)
)
env._config["project"]["id"] = project.id
env.dump()
click.secho(
f"Project [{project_name}] with namespace [{namespace}] was successfully recovered from [{prj_path}].",
fg="bright_green",
)
@click.option("--config_path", help="Path of config.", required=True)
@click.option(
"--tmpl",
help="Template of project, use default if not specified.",
default="default",
type=click.Choice(["default", "medical"], case_sensitive=False),
)
@click.option(
"--delete_cfg",
help="whether delete your defined .yaml file.",
default=False,
hidden=True,
)
def create_project(
config_path: str, tmpl: Optional[str] = None, delete_cfg: bool = False
):
"""
Create new project with a demo case.
"""
config = yaml.load(Path(config_path).read_text() or "{}")
project_config = config.get("project", {})
namespace = project_config.get("namespace", None)
name = project_config.get("namespace", None)
host_addr = project_config.get("host_addr", None)
if not namespace:
click.secho("ERROR: namespace is required.")
sys.exit()
if not re.match(r"^[A-Z][A-Za-z0-9]{0,15}$", namespace):
raise click.BadParameter(
f"Invalid namespace: {namespace}."
f" Must start with an uppercase letter, only contain letters and numbers, and have a maximum length of 16."
)
if not tmpl:
tmpl = "default"
project_id = None
llm_config_checker = LLMConfigChecker()
vectorize_model_config_checker = VectorizeModelConfigChecker()
llm_config = config.get("chat_llm", {})
vectorize_model_config = config.get("vectorizer", {})
try:
llm_config_checker.check(json.dumps(llm_config))
dim = vectorize_model_config_checker.check(json.dumps(vectorize_model_config))
config["vectorizer"]["vector_dimensions"] = dim
except Exception as e:
click.secho(f"Error: {e}", fg="bright_red")
sys.exit()
if host_addr:
client = ProjectClient(host_addr=host_addr)
project = client.create(name=name, namespace=namespace, config=json.dumps(config))
if project and project.id:
project_id = project.id
else:
click.secho("ERROR: host_addr is required.", fg="bright_red")
sys.exit()
project_dir = _render_template(
namespace=namespace,
tmpl=tmpl,
id=project_id,
with_server=(host_addr is not None),
host_addr=host_addr,
name=name,
config_path=config_path,
delete_cfg=delete_cfg,
)
current_dir = os.getcwd()
os.chdir(project_dir)
update_project(project_dir)
os.chdir(current_dir)
if delete_cfg and os.path.exists(config_path):
os.remove(config_path)
click.secho(
f"Project with namespace [{namespace}] was successfully created in {project_dir.resolve()} \n"
+ "You can checkout your project with: \n"
+ f" cd {project_dir}",
fg="bright_green",
)
@click.option("--host_addr", help="Address of spg server.", default=None)
@click.option("--proj_path", help="Path of project.", default=None)
def restore_project(host_addr, proj_path):
if host_addr is None:
host_addr = env.host_addr
if proj_path is None:
proj_path = env.project_path
proj_client = ProjectClient(host_addr=host_addr)
project_wanted = proj_client.get_by_namespace(namespace=env.namespace)
if not project_wanted:
if host_addr:
client = ProjectClient(host_addr=host_addr)
project = client.create(name=env.name, namespace=env.namespace, config=json.dumps(env.config))
project_id = project.id
else:
project_id = project_wanted.id
# write project id and host addr to kag_config.yaml
env._config["project"]["id"] = project_id
env._config["project"]["host_addr"] = host_addr
env.dump()
if proj_path:
_recover_project(proj_path)
update_project(proj_path)
@click.option("--proj_path", help="Path of config.", default=None)
def update_project(proj_path):
if not proj_path:
proj_path = env.project_path
client = ProjectClient(host_addr=env.host_addr)
llm_config_checker = LLMConfigChecker()
vectorize_model_config_checker = VectorizeModelConfigChecker()
llm_config = env.config.get("chat_llm", {})
vectorize_model_config = env.config.get("vectorizer", {})
try:
llm_config_checker.check(json.dumps(llm_config))
dim = vectorize_model_config_checker.check(json.dumps(vectorize_model_config))
env._config["vectorizer"]["vector_dimensions"] = dim
except Exception as e:
click.secho(f"Error: {e}", fg="bright_red")
sys.exit()
client.update(id=env.id, config=json.dumps(env.config))
click.secho(
f"Project [{env.name}] with namespace [{env.namespace}] was successfully updated from [{proj_path}].",
fg="bright_green",
)
@click.option("--host_addr", help="Address of spg server.", default=DEFAULT_HOST_ADDR)
def list_project(host_addr):
client = ProjectClient(
host_addr=host_addr
)
projects = client.get_all()
headers = ["Project Name", "Project ID"]
click.echo(click.style(f"{' | '.join(headers)}", fg="bright_green", bold=True))
click.echo(
click.style(
f"{'-' * (len(headers[0]) + len(headers[1]) + 3)}", fg="bright_green"
)
)
for project_name, project_id in projects.items():
click.echo(
click.style(f"{project_name:<20} | {project_id:<10}", fg="bright_green")
)