ModelLink/modellink/error_utils.py

173 lines
4.9 KiB
Python

# coding=utf-8
# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def check_condition(cond, error_message=None):
    """Raise ``RuntimeError`` when ``cond`` is falsy.

    Args:
        cond: Any value; only its truthiness is inspected.
        error_message: Optional message attached to the raised error. When
            omitted, a bare ``RuntimeError`` is raised, as before.

    Raises:
        RuntimeError: If ``cond`` is falsy.
    """
    if cond:
        return
    if error_message is None:
        raise RuntimeError
    raise RuntimeError(error_message)
class NotDivisibleError(Exception):
    """Raised when ``denominator`` is not an exact multiple of ``molecule``.

    The default message reads ``"<denominator> is not divisible by <molecule>"``,
    i.e. ``denominator`` is the value being divided and ``molecule`` the divisor;
    the parameter names are kept unchanged for backward compatibility.

    Args:
        denominator: The value that was expected to be divisible.
        molecule: The divisor.
        error_info: Optional format string with two positional ``{}`` slots,
            filled with ``(denominator, molecule)``; ``None`` selects the
            default message.
    """

    def __init__(self, denominator, molecule, error_info=None):
        # Forward the constructor arguments so ``self.args`` mirrors the
        # signature: pickling re-creates exceptions as ``cls(*self.args)``,
        # which fails with an empty ``args`` (e.g. across worker processes).
        super().__init__(denominator, molecule, error_info)
        self._error_info = error_info
        self._molecule = molecule
        self._denominator = denominator

    def __str__(self):
        if self._error_info is None:
            return f"{self._denominator} is not divisible by {self._molecule}"
        return self._error_info.format(self._denominator, self._molecule)
def check_divisible(denominator, molecule, error_info=None):
    """Ensure ``denominator % molecule`` leaves no remainder.

    Args:
        denominator: Value being divided.
        molecule: Divisor. For built-in numbers a zero divisor raises
            ``ZeroDivisionError`` from the ``%`` operator itself.
        error_info: Optional two-slot format string forwarded to
            ``NotDivisibleError``.

    Raises:
        NotDivisibleError: If the remainder is non-zero.
    """
    remainder = denominator % molecule
    if not remainder == 0:
        raise NotDivisibleError(denominator, molecule, error_info)
def check_divisible_by_zero(dividend, divisor):
    """Return ``dividend / divisor``, rejecting an ``int`` zero divisor up front.

    Only ``int`` divisors (including ``bool``) are checked explicitly. Any other
    divisor is passed straight to ``/``; a float ``0.0`` therefore still raises
    Python's own ``ZeroDivisionError``, while other types follow whatever their
    ``/`` operator does.

    Raises:
        ZeroDivisionError: If ``divisor`` is an ``int`` equal to zero.
    """
    if isinstance(divisor, int) and divisor == 0:
        # Give the error a message; a bare ``ZeroDivisionError`` renders empty.
        raise ZeroDivisionError("division by zero")
    return dividend / divisor
class NotEqualError(Exception):
    """Raised when two values that must be equal are not.

    Args:
        tensor_a: First value (the name suggests tensors; any value works here).
        tensor_b: Second value.
        error_info: Optional format string with two positional ``{}`` slots,
            filled with ``(tensor_a, tensor_b)``; ``None`` selects the default
            message.
    """

    def __init__(self, tensor_a, tensor_b, error_info=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(tensor_a, tensor_b, error_info)
        self._error_info = error_info
        self._tensor_a = tensor_a
        self._tensor_b = tensor_b

    def __str__(self):
        if self._error_info is None:
            return f"{self._tensor_a} is not equal to {self._tensor_b}"
        return self._error_info.format(self._tensor_a, self._tensor_b)
def check_equal(tensor_a, tensor_b, error_info=None):
    """Ensure ``tensor_a == tensor_b`` is truthy.

    Args:
        tensor_a: Left operand of ``==``.
        tensor_b: Right operand of ``==``.
        error_info: Optional two-slot format string forwarded to
            ``NotEqualError``.

    Raises:
        NotEqualError: If the comparison is falsy.
    """
    # ``not (a == b)`` rather than ``a != b``: only ``__eq__`` is consulted.
    if not (tensor_a == tensor_b):
        raise NotEqualError(tensor_a, tensor_b, error_info)
class NotExistError(Exception):
    """Raised when ``item`` is not found in ``container``.

    Args:
        item: The value that was looked up.
        container: The collection searched.
        error_info: Optional format string with two positional ``{}`` slots,
            filled with ``(item, container)``; ``None`` selects the default
            message.
    """

    def __init__(self, item, container, error_info=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(item, container, error_info)
        self._error_info = error_info
        self._item = item
        self._container = container

    def __str__(self):
        if self._error_info is None:
            return f"{self._item} is not in {self._container}"
        return self._error_info.format(self._item, self._container)
def check_exist(item, container, error_info=None):
    """Ensure ``item in container`` holds.

    Args:
        item: Value to look up.
        container: Anything supporting the ``in`` operator.
        error_info: Optional two-slot format string forwarded to
            ``NotExistError``.

    Returns:
        ``True`` when the item is present.

    Raises:
        NotExistError: If the item is absent.
    """
    if item not in container:
        raise NotExistError(item, container, error_info)
    return True
class NotExpectedTypeError(Exception):
    """Raised when a value is not an instance of the expected type(s).

    Args:
        variable: The offending value.
        expected_type: A type, or a (possibly nested) tuple of types, as
            accepted by ``isinstance``.
        error_message: Optional message that replaces the default one when
            truthy.
    """

    def __init__(self, variable, expected_type, error_message=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(variable, expected_type, error_message)
        self._variable = variable
        self._expected_type = expected_type
        self._error_message = error_message

    @staticmethod
    def _type_name(expected_type):
        """Render ``expected_type`` (a type or tuple of types) for messages."""
        # A tuple of types has no ``__name__``; join its members instead of
        # letting __str__ crash with AttributeError.
        if isinstance(expected_type, tuple):
            return " or ".join(
                NotExpectedTypeError._type_name(member) for member in expected_type
            )
        return getattr(expected_type, "__name__", repr(expected_type))

    def __str__(self):
        if self._error_message:
            return str(self._error_message)
        return (
            f"expected {self._type_name(self._expected_type)} "
            f"find {type(self._variable).__name__}"
        )
def check_type(variable, expected_type, error_message=None):
    """Ensure ``isinstance(variable, expected_type)`` holds.

    Args:
        variable: Value to check.
        expected_type: A type or tuple of types, as accepted by ``isinstance``.
        error_message: Optional message forwarded to ``NotExpectedTypeError``.

    Raises:
        NotExpectedTypeError: If the instance check fails.
    """
    if not isinstance(variable, expected_type):
        raise NotExpectedTypeError(variable, expected_type, error_message)
class IsNotNoneError(Exception):
    """Raised when a variable that must be ``None`` holds a value.

    Args:
        error_message: Optional message; when falsy, a default is used.
    """

    def __init__(self, error_message=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(error_message)
        self._error_message = error_message

    def __str__(self):
        # Tracebacks, logging and f-strings render exceptions through str(),
        # so the message must come from __str__, not only __repr__.
        if self._error_message:
            return str(self._error_message)
        return "expect None variable"

    def __repr__(self):
        # repr() keeps returning the bare message for existing callers.
        return self.__str__()
class IsNoneError(Exception):
    """Raised when a variable that must not be ``None`` is ``None``.

    Args:
        error_message: Optional message; when falsy, a default is used.
    """

    def __init__(self, error_message=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(error_message)
        self._error_message = error_message

    def __str__(self):
        # Tracebacks, logging and f-strings render exceptions through str(),
        # so the message must come from __str__, not only __repr__.
        if self._error_message:
            return str(self._error_message)
        return "expect not None variable"

    def __repr__(self):
        # repr() keeps returning the bare message for existing callers.
        return self.__str__()
def ensure_var_is_not_none(variable, error_message=None):
    """Ensure ``variable`` is not ``None``.

    Args:
        variable: Value to check (identity comparison with ``None``).
        error_message: Optional message forwarded to ``IsNoneError``.

    Raises:
        IsNoneError: If ``variable is None``.
    """
    if variable is None:
        raise IsNoneError(error_message=error_message)
def ensure_var_is_none(variable, error_message=None):
    """Ensure ``variable`` is ``None``.

    Args:
        variable: Value to check (identity comparison with ``None``).
        error_message: Optional message forwarded to ``IsNotNoneError``.

    Raises:
        IsNotNoneError: If ``variable`` is anything other than ``None``.
    """
    if variable is not None:
        raise IsNotNoneError(error_message)
class IsNotValidError(Exception):
    """Raised when a validity expression evaluates to a falsy value.

    Args:
        error_message: Optional message; when falsy, a default is used.
    """

    def __init__(self, error_message=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(error_message)
        self._error_message = error_message

    def __str__(self):
        # Tracebacks, logging and f-strings render exceptions through str(),
        # so the message must come from __str__, not only __repr__.
        if self._error_message:
            return str(self._error_message)
        return "Expression is not valid"

    def __repr__(self):
        # repr() keeps returning the bare message for existing callers.
        return self.__str__()
def ensure_valid(expression, error_message=None):
    """Ensure ``expression`` is truthy.

    Args:
        expression: Any value; only its truthiness is inspected.
        error_message: Optional message forwarded to ``IsNotValidError``.

    Raises:
        IsNotValidError: If ``expression`` is falsy.
    """
    if expression:
        return
    raise IsNotValidError(error_message)
class GPTDatasetSampleIndexError(Exception):
    """Error signalling a bad sample index.

    The name suggests it is raised by GPT dataset sample-index code; its
    callers are not visible here.

    Args:
        error_message: Optional message; when falsy, a default is used.
    """

    def __init__(self, error_message=None):
        # Keep ``self.args`` in sync with the signature so the exception can be
        # pickled and re-created as ``cls(*self.args)``.
        super().__init__(error_message)
        self._error_message = error_message

    def __str__(self):
        # Tracebacks, logging and f-strings render exceptions through str(),
        # so the message must come from __str__, not only __repr__.
        if self._error_message:
            return str(self._error_message)
        return "Bad sample index."

    def __repr__(self):
        # repr() keeps returning the bare message for existing callers.
        return self.__str__()