Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding type annotations to manim.utils.* #3999

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
6e0d440
Handled mypy issues in utils/bezier.py
henrikmidtiby Nov 4, 2024
8c887b5
Disable mypy errors in manim.utils.*
henrikmidtiby Nov 4, 2024
09b10c7
Fix mypy errors in utils/unit.py
henrikmidtiby Nov 4, 2024
5868534
Handle mypy errors in utils/debug.py
henrikmidtiby Nov 4, 2024
57d4971
Fix mypy issues in utils.color.*
henrikmidtiby Nov 4, 2024
f764767
Avoid circular import.
henrikmidtiby Nov 4, 2024
1021c47
Handle mypy errors in utils.simple_functions.*
henrikmidtiby Nov 5, 2024
0002711
Handle my errors in utils.testing.*
henrikmidtiby Nov 5, 2024
df742bc
Avoid circular import.
henrikmidtiby Nov 5, 2024
ce85521
Handle mypy errors in utils/family_ops.py
henrikmidtiby Nov 6, 2024
21c1a44
Handle mypy errors in utils/parameter_parsing.py
henrikmidtiby Nov 6, 2024
e155bae
Handle some of the mypy errors in utils.docbuild.*
henrikmidtiby Nov 6, 2024
2f5480d
Handle mypy errors for utils/config_ops.py
henrikmidtiby Nov 6, 2024
97227b2
Handle mypy errors from utils/commands.py
henrikmidtiby Nov 6, 2024
3a2eaee
Handle mypy errors in utils/tex_templates.py
henrikmidtiby Nov 6, 2024
d21525d
Handle mypy errors in utils/space_ops.py
henrikmidtiby Nov 6, 2024
1b1e417
Fixed most type errors in utils/rate_functions.py
henrikmidtiby Nov 8, 2024
7bb63bd
Handle type errors in utils/paths.py
henrikmidtiby Nov 8, 2024
a26da20
Handled type errors in utils/deprecation.py
henrikmidtiby Nov 8, 2024
3b2e31d
Handle type errors in utils/tex_file_writing.py
henrikmidtiby Nov 8, 2024
b1e815e
Handle type error in utils/sounds.py
henrikmidtiby Nov 8, 2024
f086fec
Handle type errors in utils/opengl.py
henrikmidtiby Nov 9, 2024
d51f929
Handle type errors in utils/module_ops.py
henrikmidtiby Nov 9, 2024
b0bd14e
Handle type errors in utils/caching.py, utils/familily.py and utils/i…
henrikmidtiby Nov 9, 2024
2f151d7
Handle type errors in utils/file_ops.py
henrikmidtiby Nov 9, 2024
224db6d
Handle type errors in utils/ipython_magic.py
henrikmidtiby Nov 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions manim/_config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ def digest_parser(self, parser: configparser.ConfigParser) -> Self:

return self

def digest_args(self, args: argparse.Namespace) -> Self:
def digest_args(self, args: argparse.Namespace | list[str]) -> Self:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args cannot be a list[str] in this case. args is an object with multiple attributes which a list doesn't have, as it can be seen in lines such as:

if args.file.suffix == ".cfg":
	args.config_file = args.file
Suggested change
def digest_args(self, args: argparse.Namespace | list[str]) -> Self:
def digest_args(self, args: argparse.Namespace) -> Self:

"""Process the config options present in CLI arguments.
Parameters
Expand Down Expand Up @@ -1395,7 +1395,7 @@ def renderer(self, value: str | RendererType) -> None:
self._set_from_enum("renderer", renderer, RendererType)

@property
def media_dir(self) -> str:
def media_dir(self) -> str | Path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 29 in manim/cli/render/output_options.py describes how the --media_dir option is parsed:

    option(
        "--media_dir",
        type=Path(),
        default=None,
        help="Path to store rendered videos and latex.",
    ),

which means it will be always parsed as a Path. Thus, this property should return and receive a Path.

Technically, all those properties can return None, because ManimConfig.__init__() initializes the internal ._d dictionary whose values are all None in the beginning, before digesting any arguments. This makes the config even harder to use with type checking. Currently, the Manim configuration is very convoluted and I don't think that MyPy errors can stop being ignored right now in mypy.ini, but there have been multiple attempts to refactor and clean the config, so I would wait for those PRs. In the meantime, I would say that simply returning Path is fine.

Suggested change
def media_dir(self) -> str | Path:
def media_dir(self) -> Path:

"""Main output directory. See :meth:`ManimConfig.get_dir`."""
return self._d["media_dir"]

Expand Down
5 changes: 4 additions & 1 deletion manim/mobject/text/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
__all__ = ["DecimalNumber", "Integer", "Variable"]

from collections.abc import Sequence
from typing import Any

import numpy as np

Expand Down Expand Up @@ -327,7 +328,9 @@ def construct(self):
self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4))
"""

def __init__(self, number=0, num_decimal_places=0, **kwargs):
def __init__(
self, number: float = 0, num_decimal_places: int = 0, **kwargs: Any
) -> None:
super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs)

def get_value(self):
Expand Down
3 changes: 2 additions & 1 deletion manim/renderer/cairo_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing

import numpy as np
import numpy.typing as npt

from manim.utils.hashing import get_hash_from_play_call

Expand Down Expand Up @@ -160,7 +161,7 @@ def render(self, scene, time, moving_mobjects):
self.update_frame(scene, moving_mobjects)
self.add_frame(self.get_frame())

def get_frame(self):
def get_frame(self) -> npt.NDArray:
Copy link
Contributor

@chopan050 chopan050 Dec 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a type alias in manim.utils.typing for this specific case:

Suggested change
def get_frame(self) -> npt.NDArray:
def get_frame(self) -> PixelArray:

"""
Gets the current frame as NumPy array.

Expand Down
2 changes: 1 addition & 1 deletion manim/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ..utils import opengl, space_ops
from ..utils.exceptions import EndSceneEarlyException, RerunSceneException
from ..utils.family import extract_mobject_family_members
from ..utils.family_ops import restructure_list_to_exclude_certain_family_members

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'restructure_list_to_exclude_certain_family_members' may not be defined if module
manim.utils.family_ops
is imported before module
manim.scene.scene
, as the
definition
of restructure_list_to_exclude_certain_family_members occurs after the cyclic
import
of manim.scene.scene.
from ..utils.file_ops import open_media_file
from ..utils.iterables import list_difference_update, list_update

Expand Down Expand Up @@ -486,7 +486,7 @@
self.moving_mobjects += mobjects
return self

def add_mobjects_from_animations(self, animations):
def add_mobjects_from_animations(self, animations: list) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a parameter or return type is a list, you can (and should) specify its contents explicitly. In this case:

Suggested change
def add_mobjects_from_animations(self, animations: list) -> None:
def add_mobjects_from_animations(self, animations: list[Animation]) -> None:

curr_mobjects = self.get_mobject_family_members()
for animation in animations:
if animation.is_introducer():
Expand Down
17 changes: 12 additions & 5 deletions manim/scene/scene_file_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pydub import AudioSegment

from manim import __version__
from manim.typing import PixelArray
from manim.typing import PixelArray, StrPath

from .. import config, logger
from .._config.logger_utils import set_file_logger
Expand All @@ -34,7 +34,7 @@
modify_atime,
write_to_movie,
)
from ..utils.sounds import get_full_sound_file_path

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'get_full_sound_file_path' may not be defined if module
manim.utils.sounds
is imported before module
manim.scene.scene_file_writer
, as the
definition
of get_full_sound_file_path occurs after the cyclic
import
of manim.scene.scene_file_writer.
from .section import DefaultSectionType, Section

if TYPE_CHECKING:
Expand Down Expand Up @@ -104,7 +104,12 @@

force_output_as_scene_name = False

def __init__(self, renderer, scene_name, **kwargs):
def __init__(
self,
renderer: Any,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
renderer: Any,
renderer: CairoRenderer | OpenGLRenderer,

In this way, the type checker can detect the .num_plays attribute which is referenced twice in this file.

scene_name: StrPath,
**kwargs: Any,
) -> None:
self.renderer = renderer
self.init_output_directories(scene_name)
self.init_audio()
Expand All @@ -118,7 +123,7 @@
name="autocreated", type_=DefaultSectionType.NORMAL, skip_animations=False
)

def init_output_directories(self, scene_name):
def init_output_directories(self, scene_name: StrPath):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def init_output_directories(self, scene_name: StrPath):
def init_output_directories(self, scene_name: StrPath) -> None:

"""Initialise output directories.

Notes
Expand Down Expand Up @@ -378,7 +383,9 @@
self.add_audio_segment(new_segment, time, **kwargs)

# Writers
def begin_animation(self, allow_write: bool = False, file_path=None):
def begin_animation(
self, allow_write: bool = False, file_path: StrPath | None = None
) -> None:
"""
Used internally by manim to stream the animation to FFMPEG for
displaying or writing to a file.
Expand All @@ -391,7 +398,7 @@
if write_to_movie() and allow_write:
self.open_partial_movie_stream(file_path=file_path)

def end_animation(self, allow_write: bool = False):
def end_animation(self, allow_write: bool = False) -> None:
"""
Internally used by Manim to stop streaming to
FFMPEG gracefully.
Expand Down
107 changes: 69 additions & 38 deletions manim/utils/bezier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
BezierPoints,
BezierPoints_Array,
ColVector,
InternalPoint3D,
InternalPoint3D_Array,
MatrixMN,
Point3D,
Point3D_Array,
Expand All @@ -54,10 +56,12 @@
@overload
def bezier(
points: Sequence[Point3D_Array],
) -> Callable[[float | ColVector], Point3D_Array]: ...
) -> Callable[[float | ColVector], Point3D | Point3D_Array]: ...
Dismissed Show dismissed Hide dismissed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case, when points is a sequence of Point3D_Arrays, the result is always a Point3D_Array.

Suggested change
) -> Callable[[float | ColVector], Point3D | Point3D_Array]: ...
) -> Callable[[float | ColVector], Point3D_Array]: ...



def bezier(points):
def bezier(
points: Point3D_Array | Sequence[Point3D_Array],
) -> Callable[[float | ColVector], Point3D | Point3D_Array]:
"""Classic implementation of a Bézier curve.

Parameters
Expand Down Expand Up @@ -111,21 +115,21 @@

if degree == 0:

def zero_bezier(t):
def zero_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
return np.ones_like(t) * P[0]

return zero_bezier

if degree == 1:

def linear_bezier(t):
def linear_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
return P[0] + t * (P[1] - P[0])

return linear_bezier

if degree == 2:

def quadratic_bezier(t):
def quadratic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
t2 = t * t
mt = 1 - t
mt2 = mt * mt
Expand All @@ -135,7 +139,7 @@

if degree == 3:

def cubic_bezier(t):
def cubic_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
t2 = t * t
t3 = t2 * t
mt = 1 - t
Expand All @@ -145,11 +149,12 @@

return cubic_bezier

def nth_grade_bezier(t):
def nth_grade_bezier(t: float | ColVector) -> Point3D | Point3D_Array:
is_scalar = not isinstance(t, np.ndarray)
if is_scalar:
B = np.empty((1, *P.shape))
else:
assert isinstance(t, np.ndarray)
t = t.reshape(-1, *[1 for dim in P.shape])
B = np.empty((t.shape[0], *P.shape))
B[:] = P
Expand All @@ -162,7 +167,8 @@
# In the end, there shall be the evaluation at t of a single Bezier curve of
# grade d, stored in the first slot of B
if is_scalar:
return B[0, 0]
val: Point3D = B[0, 0]
return val
return B[:, 0]

return nth_grade_bezier
Expand Down Expand Up @@ -700,7 +706,7 @@


# Memos explained in subdivide_bezier docstring
SUBDIVISION_MATRICES = [{} for i in range(4)]
SUBDIVISION_MATRICES: list[dict[int, npt.NDArray]] = [{} for i in range(4)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please prefer the MatrixMN type alias:

Suggested change
SUBDIVISION_MATRICES: list[dict[int, npt.NDArray]] = [{} for i in range(4)]
SUBDIVISION_MATRICES: list[dict[int, MatrixMN]] = [{} for i in range(4)]



def _get_subdivision_matrix(n_points: int, n_divisions: int) -> MatrixMN:
Expand Down Expand Up @@ -812,7 +818,9 @@
return subdivision_matrix


def subdivide_bezier(points: BezierPoints, n_divisions: int) -> Point3D_Array:
def subdivide_bezier(
points: InternalPoint3D_Array, n_divisions: int
) -> InternalPoint3D_Array:
Comment on lines +821 to +823
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the time this PR was made (see explanation below), the previous parameter type was correct: points is an array containing the control points for a Bézier curve, so BezierPoints is more appropiate than (Internal)Point3D_Array.

Actually, the previous return type was imprecise: it should've been Spline instead of Point3D_Array. Spline is a more specific kind of Point3D_Array consisting of consecutive blocks of control points for consecutive Bézier curves which also happen to be connected, forming a continuous curve.

There is a recently merged PR which renames type aliases such as InternalPoint3D (a NumPy array) to Point3D, and Point3D (anything resembling a 3D point) to Point3DLike: #4027

Therefore, since points doesn't have to be a NumPy array (because it is converted to one inside the function), the correct type would now be BezierPointsLike.

Suggested change
def subdivide_bezier(
points: InternalPoint3D_Array, n_divisions: int
) -> InternalPoint3D_Array:
def subdivide_bezier(points: BezierPointsLike, n_divisions: int) -> Spline:

r"""Subdivide a Bézier curve into :math:`n` subcurves which have the same shape.

The points at which the curve is split are located at the
Expand Down Expand Up @@ -1012,14 +1020,22 @@


@overload
def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ...
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: float
) -> Point3D: ...
Dismissed Show dismissed Hide dismissed
Comment on lines +1023 to +1025
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When this PR was made, it was correct to type start and end as InternalPoint3D. Actually, the same should have been applied to the return type.

Now, since PR #4027, it's correct to type all of it as Point3D which now always represents a NumPy array, in contrast with the new Point3DLike which can also be a tuple or list of floats.

Suggested change
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: float
) -> Point3D: ...
def interpolate(start: Point3D, end: Point3D, alpha: float) -> Point3D: ...



@overload
def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ...
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: ColVector
) -> Point3D_Array: ...
Dismissed Show dismissed Hide dismissed
Comment on lines +1029 to +1031
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as before:

Suggested change
def interpolate(
start: InternalPoint3D, end: InternalPoint3D, alpha: ColVector
) -> Point3D_Array: ...
def interpolate(start: Point3D, end: Point3D, alpha: ColVector) -> Point3D_Array: ...



def interpolate(start, end, alpha):
def interpolate(
start: float | InternalPoint3D,
end: float | InternalPoint3D,
alpha: float | ColVector,
) -> float | ColVector | Point3D | Point3D_Array:
Comment on lines +1034 to +1038
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def interpolate(
start: float | InternalPoint3D,
end: float | InternalPoint3D,
alpha: float | ColVector,
) -> float | ColVector | Point3D | Point3D_Array:
def interpolate(
start: float | Point3D,
end: float | Point3D,
alpha: float | ColVector,
) -> float | ColVector | Point3D | Point3D_Array:

"""Linearly interpolates between two values ``start`` and ``end``.

Parameters
Expand Down Expand Up @@ -1099,10 +1115,12 @@


@overload
def mid(start: Point3D, end: Point3D) -> Point3D: ...
def mid(start: InternalPoint3D, end: InternalPoint3D) -> Point3D: ...
Dismissed Show dismissed Hide dismissed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def mid(start: InternalPoint3D, end: InternalPoint3D) -> Point3D: ...
def mid(start: Point3D, end: Point3D) -> Point3D: ...



def mid(start: float | Point3D, end: float | Point3D) -> float | Point3D:
def mid(
start: float | InternalPoint3D, end: float | InternalPoint3D
) -> float | Point3D:
Comment on lines +1121 to +1123
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def mid(
start: float | InternalPoint3D, end: float | InternalPoint3D
) -> float | Point3D:
def mid(start: float | Point3D, end: float | Point3D) -> float | Point3D:

"""Returns the midpoint between two values.

Parameters
Expand All @@ -1124,15 +1142,21 @@


@overload
def inverse_interpolate(start: float, end: float, value: Point3D) -> Point3D: ...
def inverse_interpolate(
start: float, end: float, value: InternalPoint3D
) -> InternalPoint3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
Comment on lines +1145 to +1147
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def inverse_interpolate(
start: float, end: float, value: InternalPoint3D
) -> InternalPoint3D: ...
def inverse_interpolate(start: float, end: float, value: Point3D) -> Point3D: ...



@overload
def inverse_interpolate(start: Point3D, end: Point3D, value: Point3D) -> Point3D: ...
def inverse_interpolate(
start: InternalPoint3D, end: InternalPoint3D, value: InternalPoint3D
) -> Point3D: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
Comment on lines +1151 to +1153
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def inverse_interpolate(
start: InternalPoint3D, end: InternalPoint3D, value: InternalPoint3D
) -> Point3D: ...
def inverse_interpolate(start: Point3D, end: Point3D, value: Point3D) -> Point3D: ...



def inverse_interpolate(
start: float | Point3D, end: float | Point3D, value: float | Point3D
start: float | InternalPoint3D,
end: float | InternalPoint3D,
value: float | InternalPoint3D,
Comment on lines +1157 to +1159
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
start: float | InternalPoint3D,
end: float | InternalPoint3D,
value: float | InternalPoint3D,
start: float | Point3D,
end: float | Point3D,
value: float | Point3D,

) -> float | Point3D:
"""Perform inverse interpolation to determine the alpha
values that would produce the specified ``value``
Expand Down Expand Up @@ -1186,7 +1210,7 @@
new_end: float,
old_start: float,
old_end: float,
old_value: Point3D,
old_value: InternalPoint3D,
) -> Point3D: ...


Expand All @@ -1195,7 +1219,7 @@
new_end: float,
old_start: float,
old_end: float,
old_value: float | Point3D,
old_value: float | InternalPoint3D,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
old_value: float | InternalPoint3D,
old_value: float | Point3D,

) -> float | Point3D:
"""Interpolate a value from an old range to a new range.

Expand Down Expand Up @@ -1227,7 +1251,7 @@
return interpolate(
new_start,
new_end,
old_alpha, # type: ignore[arg-type]
old_alpha,
)


Expand Down Expand Up @@ -1263,7 +1287,8 @@
# they can only be an interpolation of these two anchors with alphas
# 1/3 and 2/3, which will draw a straight line between the anchors.
if n_anchors == 2:
return interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
val = interpolate(anchors[0], anchors[1], np.array([[1 / 3], [2 / 3]]))
return (val[0], val[1])

# Handle different cases depending on whether the points form a closed
# curve or not
Expand Down Expand Up @@ -1738,7 +1763,12 @@
) -> QuadraticBezierPoints_Array: ...


def get_quadratic_approximation_of_cubic(a0, h0, h1, a1):
def get_quadratic_approximation_of_cubic(
a0: Point3D | Point3D_Array,
h0: Point3D | Point3D_Array,
h1: Point3D | Point3D_Array,
a1: Point3D | Point3D_Array,
) -> QuadraticBezierPoints_Array:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and the return types of the previous overloads are not correct, because it's not the control points for a single Bézier curve, but rather 2 or 2N curves. It should be QuadraticSpline for the first overload with Point3D, because it's 2 connected quadratic Béziers, and QuadraticBezierPath for the second overload with Point3D_Array, because it's 2N quadratic Béziers which are not necessarily connected, depending on the given points.

Suggested change
) -> QuadraticBezierPoints_Array:
) -> QuadraticSpline | QuadraticBezierPath:

r"""If ``a0``, ``h0``, ``h1`` and ``a1`` are the control points of a cubic
Bézier curve, approximate the curve with two quadratic Bézier curves and
return an array of 6 points, where the first 3 points represent the first
Expand Down Expand Up @@ -1842,33 +1872,33 @@
If ``a0``, ``h0``, ``h1`` and ``a1`` have different dimensions, or
if their number of dimensions is not 1 or 2.
"""
a0 = np.asarray(a0)
h0 = np.asarray(h0)
h1 = np.asarray(h1)
a1 = np.asarray(a1)

if all(arr.ndim == 1 for arr in (a0, h0, h1, a1)):
num_curves, dim = 1, a0.shape[0]
elif all(arr.ndim == 2 for arr in (a0, h0, h1, a1)):
num_curves, dim = a0.shape
a0c = np.asarray(a0)
h0c = np.asarray(h0)
h1c = np.asarray(h1)
a1c = np.asarray(a1)

if all(arr.ndim == 1 for arr in (a0c, h0c, h1c, a1c)):
num_curves, dim = 1, a0c.shape[0]
elif all(arr.ndim == 2 for arr in (a0c, h0c, h1c, a1c)):
num_curves, dim = a0c.shape
else:
raise ValueError("All arguments must be Point3D or Point3D_Array.")

m0 = 0.25 * (3 * h0 + a0)
m1 = 0.25 * (3 * h1 + a1)
m0 = 0.25 * (3 * h0c + a0c)
m1 = 0.25 * (3 * h1c + a1c)
k = 0.5 * (m0 + m1)

result = np.empty((6 * num_curves, dim))
result[0::6] = a0
result[0::6] = a0c
result[1::6] = m0
result[2::6] = k
result[3::6] = k
result[4::6] = m1
result[5::6] = a1
result[5::6] = a1c
return result


def is_closed(points: Point3D_Array) -> bool:
def is_closed(points: InternalPoint3D_Array) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def is_closed(points: InternalPoint3D_Array) -> bool:
def is_closed(points: Point3D_Array) -> bool:

"""Returns ``True`` if the spline given by ``points`` is closed, by
checking if its first and last points are close to each other, or``False``
otherwise.
Expand Down Expand Up @@ -1938,7 +1968,8 @@
return False
if abs(end[1] - start[1]) > tolerance[1]:
return False
return abs(end[2] - start[2]) <= tolerance[2]
val: bool = abs(end[2] - start[2]) <= tolerance[2]
return val


def proportions_along_bezier_curve_for_point(
Expand Down
Loading
Loading