Extend `stub.py` to accept external typehinting (#1102)
This commit is contained in:
parent
b355ab4e2e
commit
f9e93f5b69
|
@ -0,0 +1,3 @@
|
|||
This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
|
||||
|
||||
The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.
|
|
@ -0,0 +1,55 @@
|
|||
from typing import Union, Sequence
|
||||
|
||||
|
||||
class Tensor:
|
||||
"""
|
||||
This contains the type hints for the magic methodes of the `candle.Tensor` class.
|
||||
"""
|
||||
|
||||
def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Subtract a scalar from a tensor or one tensor from another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Divide a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
|
||||
def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
|
||||
"""
|
||||
Return a slice of a tensor.
|
||||
"""
|
||||
pass
|
|
@ -1,7 +1,7 @@
|
|||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
|
||||
class bf16(DType):
|
||||
pass
|
||||
|
@ -119,6 +119,46 @@ class Tensor:
|
|||
|
||||
def __init__(self, data: _ArrayLike):
|
||||
pass
|
||||
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
|
||||
"""
|
||||
Return a slice of a tensor.
|
||||
"""
|
||||
pass
|
||||
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Add a scalar to a tensor or two tensors together.
|
||||
"""
|
||||
pass
|
||||
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
|
||||
"""
|
||||
Compare a tensor with a scalar or one tensor with another.
|
||||
"""
|
||||
pass
|
||||
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Multiply a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Subtract a scalar from a tensor or one tensor from another.
|
||||
"""
|
||||
pass
|
||||
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
|
||||
"""
|
||||
Divide a tensor by a scalar or one tensor by another.
|
||||
"""
|
||||
pass
|
||||
def argmax_keepdim(self, dim: int) -> Tensor:
|
||||
"""
|
||||
Returns the indices of the maximum value(s) across the selected dimension.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -14,3 +14,7 @@ CPU: str = "cpu"
|
|||
CUDA: str = "cuda"
|
||||
|
||||
Device = TypeVar("Device", CPU, CUDA)
|
||||
|
||||
Scalar = Union[int, float]
|
||||
|
||||
Index = Union[int, slice, None, "Ellipsis"]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Generated content DO NOT EDIT
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
from candle.typing import _ArrayLike, Device
|
||||
from candle.typing import _ArrayLike, Device, Scalar, Index
|
||||
from candle import Tensor, DType, QTensor
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
from typing import Optional
|
||||
import black
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
||||
|
||||
INDENT = " " * 4
|
||||
|
@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
|
|||
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
|
||||
from os import PathLike
|
||||
"""
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
|
||||
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
|
||||
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
|
||||
RETURN_TYPE_MARKER = "&RETURNS&: "
|
||||
ADDITIONAL_TYPEHINTS = {}
|
||||
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
|
||||
|
||||
|
||||
def do_indent(text: Optional[str], indent: str):
|
||||
|
@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
|
|||
body += f"{indent+INDENT}pass\n"
|
||||
body += "\n"
|
||||
|
||||
if obj.__name__ in ADDITIONAL_TYPEHINTS:
|
||||
additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
|
||||
additional_functions = []
|
||||
for name, member in additional_members:
|
||||
if inspect.isfunction(member):
|
||||
additional_functions.append((name, member))
|
||||
|
||||
def process_additional_function(fn):
|
||||
signature = inspect.signature(fn)
|
||||
cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
|
||||
string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
|
||||
string += (
|
||||
f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
|
||||
)
|
||||
string += f"{indent+INDENT}pass\n"
|
||||
string += "\n"
|
||||
return string
|
||||
|
||||
for name, fn in additional_functions:
|
||||
body += process_additional_function(fn)
|
||||
|
||||
for name, fn in fns:
|
||||
body += pyi_file(fn, indent=indent)
|
||||
|
||||
|
@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
|
|||
write(submodule, os.path.join(directory, name), f"{name}", check=check)
|
||||
|
||||
|
||||
def extract_additional_types(module):
|
||||
additional_types = {}
|
||||
for name, member in inspect.getmembers(module):
|
||||
if inspect.isclass(member):
|
||||
if hasattr(member, "__name__"):
|
||||
name = member.__name__
|
||||
else:
|
||||
name = str(member)
|
||||
if name not in additional_types:
|
||||
additional_types[name] = member
|
||||
return additional_types
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--check", action="store_true")
|
||||
|
@ -228,5 +265,8 @@ if __name__ == "__main__":
|
|||
directory = f"candle-pyo3/{directory}"
|
||||
|
||||
import candle
|
||||
import _additional_typing
|
||||
|
||||
ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
|
||||
|
||||
write(candle.candle, directory, "candle", check=args.check)
|
||||
|
|
Loading…
Reference in New Issue