!614 Add a unified test framework for the CI gate and pipeline, and supplement inference UTs

Merge pull request !614 from guoxinjie/inference
guoxinjie 2024-02-20 02:05:23 +00:00 committed by i-robot
parent 409848a448
commit c27a6624a1
8 changed files with 650 additions and 1 deletion


@@ -51,6 +51,7 @@ class ST_Test:
        for shell_file in self.shell_file_list:
            success_check(os.system("sh {}".format(shell_file)))

        # ===============================================
        # UT test, run with pytest, waiting for more ...
        # ===============================================

tests/pipeline/common.py (new file, 341 lines)

@@ -0,0 +1,341 @@
# Copyright (c) Microsoft Corporation.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/microsoft/DeepSpeed/blob/master/tests/unit/common.py
# reworked/refactored some parts to make it run.
import os
import time
import inspect
import socket
from abc import ABC, abstractmethod

import torch
import torch.multiprocessing as mp
import torch.distributed as dist

import pytest
from _pytest.outcomes import Skipped
from _pytest.fixtures import FixtureLookupError, FixtureFunctionMarker

# Worker timeout for tests that hang
TEST_TIMEOUT = 600
def get_xdist_worker_id():
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
    if xdist_worker is not None:
        xdist_worker_id = xdist_worker.replace("gw", "")
        return int(xdist_worker_id)
    return None


def get_master_port(base_port=29500, port_range_size=1000):
    xdist_worker_id = get_xdist_worker_id()
    if xdist_worker_id is not None:
        # Make xdist workers use different port ranges to avoid race conditions
        base_port += port_range_size * xdist_worker_id

    # Select first open port in range
    port = base_port
    max_port = base_port + port_range_size
    sock = socket.socket()
    while port < max_port:
        try:
            sock.bind(("", port))
            sock.close()
            return str(port)
        except OSError:
            port += 1
    raise IOError("no free ports")
class DistributedExec(ABC):
    """
    Base class for distributed execution of functions/methods. Contains common
    methods needed for DistributedTest and DistributedFixture.
    """
    world_size = 2
    backend = "nccl"
    init_distributed = True
    set_dist_env = True
    reuse_dist_env = False
    _pool_cache = {}
    exec_timeout = TEST_TIMEOUT

    @abstractmethod
    def run(self):
        ...

    def __call__(self, request=None):
        self._fixture_kwargs = self._get_fixture_kwargs(request, self.run)
        world_size = self.world_size
        if not torch.cuda.is_available():
            pytest.skip("only supported in accelerator environments.")

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)

    def _get_fixture_kwargs(self, request, func):
        if not request:
            return {}
        # Grab fixture / parametrize kwargs from pytest request object
        fixture_kwargs = {}
        params = inspect.getfullargspec(func).args
        params.remove("self")
        for p in params:
            try:
                fixture_kwargs[p] = request.getfixturevalue(p)
            except FixtureLookupError:
                pass  # test methods can have kwargs that are not fixtures
        return fixture_kwargs
    def _launch_procs(self, num_procs):
        # Verify we have enough accelerator devices to run this test
        if torch.cuda.is_available() and torch.cuda.device_count() < num_procs:
            pytest.skip(
                f"Skipping test because not enough GPUs are available: {num_procs} required, {torch.cuda.device_count()} available"
            )

        # Set start method to `forkserver` (or `fork`)
        mp.set_start_method("forkserver", force=True)

        # Create process pool or use cached one
        master_port = None
        if self.reuse_dist_env:
            if num_procs not in self._pool_cache:
                self._pool_cache[num_procs] = mp.Pool(processes=num_procs)
                master_port = get_master_port()
            pool = self._pool_cache[num_procs]
        else:
            pool = mp.Pool(processes=num_procs)
            master_port = get_master_port()

        # Run the test
        args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)]
        skip_msgs_async = pool.starmap_async(self._dist_run, args)

        try:
            skip_msgs = skip_msgs_async.get(self.exec_timeout)
        except mp.TimeoutError:
            # Shortcut to exit pytest in the case of a hanged test. This
            # usually means an environment error and the rest of tests will
            # hang (causing super long unit test runtimes)
            pytest.exit("Test hanged, exiting", returncode=0)

        # Tear down distributed environment and close process pools
        self._close_pool(pool, num_procs)

        # If we skipped a test, propagate that to this process
        if any(skip_msgs):
            assert len(set(skip_msgs)) == 1, "Multiple different skip messages received"
            pytest.skip(skip_msgs[0])
    def _dist_run(self, local_rank, num_procs, master_port):
        skip_msg = ""
        if not dist.is_initialized():
            """ Initialize torch.distributed and execute the user function. """
            if self.set_dist_env:
                os.environ["MASTER_ADDR"] = "127.0.0.1"
                os.environ["MASTER_PORT"] = str(master_port)
                os.environ["LOCAL_RANK"] = str(local_rank)
                # NOTE: unit tests don't support multi-node so local_rank == global rank
                os.environ["RANK"] = str(local_rank)
                # In case of multiprocess launching LOCAL_SIZE should be same as WORLD_SIZE
                # single node launcher would also set LOCAL_SIZE accordingly
                os.environ["LOCAL_SIZE"] = str(num_procs)
                os.environ["WORLD_SIZE"] = str(num_procs)

            print(
                f"Initializing torch.distributed with rank: {local_rank}, world_size: {num_procs}"
            )
            torch.cuda.set_device(local_rank % torch.cuda.device_count())
            init_method = "tcp://"
            master_ip = os.getenv("MASTER_ADDR", "localhost")
            master_port = str(master_port)
            init_method += master_ip + ":" + master_port
            torch.distributed.init_process_group(
                backend=self.backend,
                world_size=num_procs,
                rank=local_rank,
                init_method=init_method,
            )
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

        try:
            self.run(**self._fixture_kwargs)
        except BaseException as e:
            if isinstance(e, Skipped):
                skip_msg = e.msg
            else:
                raise e

        return skip_msg

    def _dist_destroy(self):
        if (dist is not None) and dist.is_initialized():
            dist.barrier()
            dist.destroy_process_group()

    def _close_pool(self, pool, num_procs, force=False):
        if force or not self.reuse_dist_env:
            msg = pool.starmap(self._dist_destroy, [() for _ in range(num_procs)])
            pool.close()
            pool.join()
class DistributedFixture(DistributedExec):
    """
    Implementation that extends @pytest.fixture to allow for distributed execution.
    This is primarily meant to be used when a test requires executing two pieces of
    code with different world sizes.

    There are 2 parameters that can be modified:
        - world_size: int = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside fixture
        - can be reused by multiple tests
        - can accept other fixtures as input

    Limitations:
        - cannot use @pytest.mark.parametrize
        - world_size cannot be modified after definition and only one world_size value is accepted
        - any fixtures used must also be used in the test that uses this fixture (see example below)
        - return values cannot be returned. Passing values to a DistributedTest
          object can be achieved using class_tmpdir and writing to file (see example below)

    Usage:
        - must implement a run(self, ...) method
        - fixture can be used by making the class name input to a test function

    Example:
        @pytest.fixture(params=[10,20])
        def regular_pytest_fixture(request):
            return request.param

        class distributed_fixture_example(DistributedFixture):
            world_size = 4

            def run(self, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                local_rank = os.environ["LOCAL_RANK"]
                print(f"Rank {local_rank} with value {regular_pytest_fixture}")
                with open(os.path.join(class_tmpdir, f"{local_rank}.txt"), "w") as f:
                    f.write(f"{local_rank},{regular_pytest_fixture}")

        class TestExample(DistributedTest):
            world_size = 1

            def test(self, distributed_fixture_example, regular_pytest_fixture, class_tmpdir):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                for rank in range(4):
                    with open(os.path.join(class_tmpdir, f"{rank}.txt"), "r") as f:
                        assert f.read() == f"{rank},{regular_pytest_fixture}"
    """
    is_dist_fixture = True

    # These values are just placeholders so that pytest recognizes this as a fixture
    _pytestfixturefunction = FixtureFunctionMarker(scope="function", params=None)
    __name__ = ""

    def __init__(self):
        assert isinstance(
            self.world_size, int
        ), "Only one world size is allowed for distributed fixtures"
        self.__name__ = type(self).__name__
        _pytestfixturefunction = FixtureFunctionMarker(
            scope="function", params=None, name=self.__name__
        )
class DistributedTest(DistributedExec):
    """
    Implementation for running pytest with distributed execution.

    There are 2 parameters that can be modified:
        - world_size: Union[int,List[int]] = 2 -- the number of processes to launch
        - backend: Literal['nccl','mpi','gloo'] = 'nccl' -- which backend to use

    Features:
        - able to call pytest.skip() inside tests
        - works with pytest fixtures, parametrize, mark, etc.
        - can contain multiple tests (each of which can be parametrized separately)
        - class methods can be fixtures (usable by tests in this class only)
        - world_size can be changed for individual tests using @pytest.mark.world_size(world_size)
        - class_tmpdir is a fixture that can be used to get a tmpdir shared among
          all tests (including DistributedFixture)

    Usage:
        - class name must start with "Test"
        - must implement one or more test*(self, ...) methods

    Example:
        @pytest.fixture(params=[10,20])
        def val1(request):
            return request.param

        @pytest.mark.fast
        @pytest.mark.parametrize("val2", [30,40])
        class TestExample(DistributedTest):
            world_size = 2

            @pytest.fixture(params=[50,60])
            def val3(self, request):
                return request.param

            def test_1(self, val1, val2, str1="hello world"):
                assert int(os.environ["WORLD_SIZE"]) == self.world_size
                assert all(val1, val2, str1)

            @pytest.mark.world_size(1)
            @pytest.mark.parametrize("val4", [70,80])
            def test_2(self, val1, val2, val3, val4):
                assert int(os.environ["WORLD_SIZE"]) == 1
                assert all(val1, val2, val3, val4)
    """
    is_dist_test = True

    # Temporary directory that is shared among test methods in a class
    @pytest.fixture(autouse=True, scope="class")
    def class_tmpdir(self, tmpdir_factory):
        fn = tmpdir_factory.mktemp(self.__class__.__name__)
        return fn

    def run(self, **fixture_kwargs):
        self._current_test(**fixture_kwargs)

    def __call__(self, request):
        self._current_test = self._get_current_test_func(request)
        self._fixture_kwargs = self._get_fixture_kwargs(request, self._current_test)

        if not torch.cuda.is_available():
            pytest.skip("only supported in accelerator environments.")

        # Catch world_size override pytest mark
        for mark in getattr(request.function, "pytestmark", []):
            if mark.name == "world_size":
                world_size = mark.args[0]
                break
        else:
            world_size = self.world_size

        if isinstance(world_size, int):
            world_size = [world_size]
        for procs in world_size:
            self._launch_procs(procs)
        time.sleep(0.5)

    def _get_current_test_func(self, request):
        # DistributedTest subclasses may have multiple test methods
        func_name = request.function.__name__
        return getattr(self, func_name)
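
For reference, `_dist_run` above assumes a single node, so `LOCAL_RANK` and `RANK` are always equal and `LOCAL_SIZE` mirrors `WORLD_SIZE`. A minimal sketch of the environment a `world_size = 2` launch produces in each pool worker (the port value is illustrative, whatever `get_master_port` picked):

    # Rank 0 and rank 1 differ only in their rank variables.
    env_rank0 = {"MASTER_ADDR": "127.0.0.1", "MASTER_PORT": "29500",
                 "LOCAL_RANK": "0", "RANK": "0",
                 "LOCAL_SIZE": "2", "WORLD_SIZE": "2"}
    env_rank1 = {**env_rank0, "LOCAL_RANK": "1", "RANK": "1"}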


@@ -0,0 +1,43 @@
# Copyright (c) Microsoft Corporation.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/microsoft/DeepSpeed/blob/master/tests/conftest.py
# reworked/refactored some parts to make it run.
import pytest


def pytest_configure(config):
    config.option.color = "yes"
    config.option.durations = 0
    config.option.durations_min = 1
    config.option.verbose = True


# Override of pytest "runtest" for DistributedTest class
# This hook is run before the default pytest_runtest_call
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # We want to use our own launching function for distributed tests
    if getattr(item.cls, "is_dist_test", False):
        dist_test_class = item.cls()
        dist_test_class(item._request)
        item.runtest = lambda: True  # Dummy function so test is not run twice


# We allow DistributedTest to reuse distributed environments. When the last
# test for a class is run, we want to make sure those distributed environments
# are destroyed.
def pytest_runtest_teardown(item, nextitem):
    if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
        dist_test_class = item.cls()
        for num_procs, pool in dist_test_class._pool_cache.items():
            dist_test_class._close_pool(pool, num_procs, force=True)


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
    if getattr(fixturedef.func, "is_dist_fixture", False):
        dist_fixture_class = fixturedef.func()
        dist_fixture_class(request)
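
Taken together, `pytest_runtest_call` instantiates the test class, hands it the pytest `request`, and stubs out `item.runtest`, so the test body only ever executes inside the spawned worker processes. A minimal sketch of the kind of test this machinery picks up (the class and the all-reduce check are illustrative, not part of this patch):

    # test_sanity.py -- illustrative only
    import torch
    import torch.distributed as dist
    from common import DistributedTest

    class TestAllReduceSanity(DistributedTest):
        world_size = 2  # _launch_procs spawns two pool workers

        def test_all_reduce(self):
            # Each rank contributes its rank id; after the default SUM, both see 0 + 1 = 1.
            value = torch.ones(1).cuda() * dist.get_rank()
            dist.all_reduce(value)
            assert value.item() == 1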


@@ -0,0 +1,8 @@
# Provide uniform access for the pipeline.

python -m unittest test_process_pretrain_data.py
python -m unittest test_process_instruction_data.py
python -m unittest test_convert_ckpt_from_huggingface.py
pytest -s test_generation.py
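
The final pytest line is the one that goes through the new DistributedTest machinery. Should a single case need to be rerun in isolation, an equivalent programmatic invocation would be the following sketch (node-id syntax is standard pytest; it assumes the same working directory as this script):

    import pytest

    # -s keeps the generated text printed by test_generation.py visible on stdout.
    pytest.main(["-s", "test_generation.py::TestGeneration::test_greedy_search"])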


@@ -0,0 +1,100 @@
{
    "NETWORK_SIZE": [
        "--num-layers", "32",
        "--hidden-size", "4096",
        "--ffn-hidden-size", "11008",
        "--num-attention-heads", "32",
        "--max-position-embeddings", "4096",
        "--position-embedding-type", "rope",
        "--make-vocab-size-divisible-by", "1",
        "--normalization", "RMSNorm",
        "--swiglu",
        "--untie-embeddings-and-output-weights"
    ],

    "INFERENCE_PARAM": [
        "--tokenizer-type", "PretrainedFromHF",
        "--tokenizer-model", "./tokenizer.model",
        "--tokenizer-name-or-path", "./llama2-7B",
        "--load", "./llama2-7B-tp8-pp1",
        "--max-new-tokens", "256",
        "--seed", "42",
        "--tokenizer-not-use-fast",
        "--exit-on-missing-checkpoint"
    ],

    "EVALUATION_PARAM": [
        "--use-checkpoint-args",
        "--task-data-path", "./eval_dataset/boolq/test/", "./eval_dataset/mmlu/test/",
        "--task", "boolq",
        "--max-new-tokens", "2"
    ],

    "LORA_PARAM": [
        "--finetune",
        "--lora-r", "16",
        "--lora-alpha", "32",
        "--lora-target-modules", "query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h",
        "--load", "./llama2-7B-tp8-pp1",
        "--save", "./lora_weight",
        "--data-path", "./alpaca_text_document",
        "--train-iters", "10"
    ],

    "TRAINING_PARAM": [
        "--save", "./save_weight",
        "--data-path", "./alpaca_text_document",
        "--train-iters", "15"
    ],

    "REGULARIZATION": [
        "--attention-dropout", "0.0",
        "--hidden-dropout", "0.0",
        "--weight-decay", "1e-1",
        "--clip-grad", "1.0",
        "--adam-beta1", "0.9",
        "--adam-beta2", "0.95"
    ],

    "LEARNING_RATE": [
        "--lr", "1.25e-6",
        "--lr-decay-style", "cosine",
        "--lr-warmup-fraction", "0.01",
        "--min-lr", "1.25e-7"
    ],

    "DISTRIBUTED_PARAM": [
        "--tensor-model-parallel-size", "8",
        "--pipeline-model-parallel-size", "1"
    ],

    "AUXILIARY_PARAM": [
        "--micro-batch-size", "4",
        "--global-batch-size", "16",
        "--no-masked-softmax-fusion",
        "--disable-bias-linear",
        "--no-gradient-accumulation-fusion",
        "--bf16",
        "--attention-softmax-in-fp32",
        "--no-load-optim",
        "--no-load-rng",
        "--seq-length", "4096"
    ],

    "TRAINING_AUX": [
        "--sequence-parallel",
        "--initial-loss-scale", "65536",
        "--use-flash-attn",
        "--use-fused-rmsnorm",
        "--init-method-std", "0.01",
        "--split", "100,0,0",
        "--log-interval", "1",
        "--save-interval", "10",
        "--eval-interval", "1000",
        "--eval-iters", "0",
        "--num-workers", "0",
        "--distributed-backend", "nccl",
        "--tokenizer-type", "Llama2Tokenizer",
        "--tokenizer-model", "./tokenizer.model"
    ]
}
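
Each top-level key holds a flat, argv-style token list rather than nested options, which lets a test splice whole sections straight into `sys.argv` before Megatron parses them (the pattern `TestGeneration.init` uses below). A minimal sketch of that wiring, assuming the file sits in the working directory:

    import json
    import sys

    with open("param_config.json") as f:
        cfg = json.load(f)

    # Equivalent to a command line of --tensor-model-parallel-size 8 ... --seq-length 4096
    sys.argv = [sys.argv[0]] + cfg["DISTRIBUTED_PARAM"] + cfg["NETWORK_SIZE"] \
               + cfg["INFERENCE_PARAM"] + cfg["AUXILIARY_PARAM"]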


@@ -38,7 +38,7 @@ class TestConvertCkptFromHuggingface(unittest.TestCase):
         including embedding, final_norm, output and encoder. In the encoder, there will be some different layers
         to compose the unique transformer layer and all these layer stack to compose the entity of the model.
         """
-        base_dir = Path(__file__).absolute().parent.parent.parent
+        base_dir = Path(__file__).absolute().parent.parent.parent.parent
         file_path = os.path.join(base_dir, "tools/checkpoint/util.py")
         arguments = sys.argv[1:]
         subprocess.run(["python", file_path] + arguments)


@@ -0,0 +1,115 @@
import sys
import os

import nltk
import torch
import torch_npu
from ascendspeed import megatron_adaptor
from common import DistributedTest
from utils import ParamConfig, assert_judge

import modellink
from megatron.model import GPTModel
from tasks.inference.infer_base import add_text_generate_args


class TestGeneration(DistributedTest):
    world_size = 8

    def init(self, config=ParamConfig):
        """
        initialize the environment and arguments
        """
        sys.argv = [sys.argv[0]] + config.distributed_param + config.network_size + \
                   config.inference_param + config.auxiliary_param
        from megatron.initialize import initialize_megatron
        os.environ.update({"CUDA_DEVICE_MAX_CONNECTIONS": "1"})
        initialize_megatron(extra_args_provider=add_text_generate_args,
                            args_defaults={'no_load_rng': True,
                                           'no_load_optim': True})
        from megatron import get_args
        self.args = get_args()

    def edit_distance_similarity(self, text1, text2):
        """
        edit distance: to compare the similarity between two texts.
        """
        distance = nltk.edit_distance(text1, text2)
        try:
            similarity = 1 - (distance / max(len(text1), len(text2)))
        except ZeroDivisionError as e:
            raise e
        return similarity

    def test_greedy_search(self):
        """
        load weight to get model and construct the prompts to generate output,
        and compare with expected for `greedy search`.
        """
        self.init(config=ParamConfig)
        from tasks.inference.inference_llama import model_provider
        model = GPTModel.from_pretrained(
            model_provider=model_provider,
            pretrained_model_name_or_path=self.args.load
        )
        instruction = ["how are you?", "Give me three tips for staying healthy."]
        output = model.generate(instruction)
        expect_output1 = [
            "I'm doing well, thanks for asking! I've been keeping busy with work and spending time with friends and family. ",
            "It's been great to have some time off from school and just relax a bit. How about you? How have you been?\n",
            "\nI hope you're doing well! It's always great to catch up with you and hear about what's going on in your life. ",
            "I'm looking forward to hearing all about it. Let me know if you want to hang out soon!"
        ]
        expect_output2 = [
            '\n\n1. Eat a balanced diet: A healthy diet should include a variety of fruits, vegetables, whole grains, lean proteins, and healthy fats. ',
            'Aim to include a rainbow of colors on your plate to ensure you are getting a range of vitamins and minerals.',
            '\n2. Stay hydrated: Drink plenty of water throughout the day, aiming for at least eight cups (64 ounces) daily. ',
            'Limit your consumption of sugary drinks'
        ]

        expect_output1_seq = "".join(expect_output1)
        expect_output2_seq = ''.join(expect_output2)

        if torch.distributed.get_rank() == 0:
            print(output[0])
            print(output[1])
            similarity1 = self.edit_distance_similarity(output[0][:30], expect_output1_seq[:30])
            similarity2 = self.edit_distance_similarity(output[1][:30], expect_output2_seq[:30])
            print("similarity1:", similarity1)
            print("similarity2:", similarity2)
            assert_judge(similarity1 > 0.85)
            assert_judge(similarity2 > 0.85)

    def test_beam_search(self):
        """
        load weight to get model and construct the prompts to generate output,
        and compare with expected for `beam search`.
        """
        self.init(config=ParamConfig)
        from tasks.inference.inference_llama import model_provider
        model = GPTModel.from_pretrained(
            model_provider=model_provider,
            pretrained_model_name_or_path=self.args.load
        )

        max_new_tokens = self.args.max_new_tokens
        instruction = "What is the whether like today?"
        output = model.generate(
            instruction,
            num_beams=2,
            top_k=self.args.top_k,
            top_p=self.args.top_p,
            max_new_tokens=max_new_tokens,
            tokenizer=None,
            stream=False
        )
        expected_output = [
            "Answer:\nThe weather today is sunny with a high of 75 degrees Fahrenheit and a low of 50 degrees Fahrenheit. ",
            "There is no rain or other weather alerts in the area.",
            "\nWould you like to know the weather for a different location?"
        ]
        expected_output_seq = "".join(expected_output)

        if torch.distributed.get_rank() == 0:
            similarity = self.edit_distance_similarity(output[:40], expected_output_seq[:40])
            print(output)
            print("similarity:", similarity)
            assert_judge(similarity > 0.75)
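
`edit_distance_similarity` normalizes the Levenshtein distance by the longer of the two texts, so 1.0 means the compared prefixes are identical, and the 0.85 / 0.75 thresholds tolerate small decoding differences. A quick worked example with the same nltk call:

    import nltk

    text1, text2 = "kitten", "sitting"
    distance = nltk.edit_distance(text1, text2)               # 3 edits
    similarity = 1 - distance / max(len(text1), len(text2))   # 1 - 3/7
    print(round(similarity, 3))                               # 0.571, well below the 0.85 bar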


@@ -0,0 +1,41 @@
import json
import os
from pathlib import Path
from dataclasses import dataclass


@dataclass
class ParamConfig:
    """
    We can configure the params in the `.json` file, including:
        distributed_param,
        network_size,
        inference_param,
        evaluation_param,
        lora_param,
        training_param,
        training_auxiliary,
        learning_rate,
        regularization,
        and other auxiliary_param.
    """
    base_dir = Path(__file__).absolute().parent
    param_config = os.path.join(base_dir, "param_config.json")
    with open(param_config) as f:
        config_file = json.load(f)

    distributed_param = config_file["DISTRIBUTED_PARAM"]
    network_size = config_file["NETWORK_SIZE"]
    inference_param = config_file["INFERENCE_PARAM"]
    evaluation_param = config_file["EVALUATION_PARAM"]
    lora_param = config_file["LORA_PARAM"]
    training_param = config_file["TRAINING_PARAM"]
    training_aux = config_file["TRAINING_AUX"]
    learning_rate_param = config_file["LEARNING_RATE"]
    regularization = config_file["REGULARIZATION"]
    auxiliary_param = config_file["AUXILIARY_PARAM"]


def assert_judge(expression):
    if not expression:
        raise AssertionError