Extend `stub.py` to accept external typehinting (#1102)

This commit is contained in:
Lukas Kreussel 2023-10-17 12:07:26 +02:00 committed by GitHub
parent b355ab4e2e
commit f9e93f5b69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 146 additions and 4 deletions

View File

@ -0,0 +1,3 @@
This python module contains external typehinting for certain `candle` classes. This is only necessary for `magic` methodes e.g. `__add__` as their text signature cant be set via pyo3.
The classes in this module will be parsed by the `stub.py` script and interleafed with the signatures of the actual pyo3 `candle.candle` module.

View File

@ -0,0 +1,55 @@
from typing import Union, Sequence
class Tensor:
"""
This contains the type hints for the magic methodes of the `candle.Tensor` class.
"""
def __add__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __radd__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __sub__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Subtract a scalar from a tensor or one tensor from another.
"""
pass
def __truediv__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Divide a tensor by a scalar or one tensor by another.
"""
pass
def __mul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __rmul__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __richcmp__(self, rhs: Union["Tensor", "Scalar"], op) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union["Index", "Tensor", Sequence["Index"]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
from candle.typing import _ArrayLike, Device, Scalar, Index
class bf16(DType):
pass
@ -119,6 +119,46 @@ class Tensor:
def __init__(self, data: _ArrayLike):
pass
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
"""
pass
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Subtract a scalar from a tensor or one tensor from another.
"""
pass
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Divide a tensor by a scalar or one tensor by another.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod

View File

@ -14,3 +14,7 @@ CPU: str = "cpu"
CUDA: str = "cuda"
Device = TypeVar("Device", CPU, CUDA)
Scalar = Union[int, float]
Index = Union[int, slice, None, "Ellipsis"]

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle import Tensor, DType, QTensor
@staticmethod

View File

@ -5,6 +5,7 @@ import os
from typing import Optional
import black
from pathlib import Path
import re
INDENT = " " * 4
@ -12,9 +13,11 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device\n"
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {}
FORWARD_REF_PATTERN = re.compile(r"ForwardRef\('([^']+)'\)")
def do_indent(text: Optional[str], indent: str):
@ -115,6 +118,27 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"
if obj.__name__ in ADDITIONAL_TYPEHINTS:
additional_members = inspect.getmembers(ADDITIONAL_TYPEHINTS[obj.__name__])
additional_functions = []
for name, member in additional_members:
if inspect.isfunction(member):
additional_functions.append((name, member))
def process_additional_function(fn):
signature = inspect.signature(fn)
cleaned_signature = re.sub(FORWARD_REF_PATTERN, r"\1", str(signature))
string = f"{indent}def {fn.__name__}{cleaned_signature}:\n"
string += (
f'{indent+INDENT}"""{indent+INDENT}{do_indent(fn.__doc__, indent+INDENT)}{indent+INDENT}"""\n'
)
string += f"{indent+INDENT}pass\n"
string += "\n"
return string
for name, fn in additional_functions:
body += process_additional_function(fn)
for name, fn in fns:
body += pyi_file(fn, indent=indent)
@ -215,6 +239,19 @@ def write(module, directory, origin, check=False):
write(submodule, os.path.join(directory, name), f"{name}", check=check)
def extract_additional_types(module):
additional_types = {}
for name, member in inspect.getmembers(module):
if inspect.isclass(member):
if hasattr(member, "__name__"):
name = member.__name__
else:
name = str(member)
if name not in additional_types:
additional_types[name] = member
return additional_types
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
@ -228,5 +265,8 @@ if __name__ == "__main__":
directory = f"candle-pyo3/{directory}"
import candle
import _additional_typing
ADDITIONAL_TYPEHINTS = extract_additional_types(_additional_typing)
write(candle.candle, directory, "candle", check=args.check)