Skip to content

Commit

Permalink
All tests pass again with mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
dgasmith committed May 19, 2024
1 parent bf6ba5a commit 96c3b22
Showing 1 changed file with 8 additions and 14 deletions.
22 changes: 8 additions & 14 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,6 @@ def _choose_memory_arg(memory_limit: _MemoryLimit, size_list: List[int]) -> Opti
return int(memory_limit)


_VALID_CONTRACT_KWARGS = {
"optimize",
"path",
"memory_limit",
"einsum_call",
"use_blas",
"shapes",
}


@overload
def contract_path(
subscripts: str,
Expand Down Expand Up @@ -340,7 +330,7 @@ def contract_path(
path_tuple = [tuple(range(num_ops))]
elif isinstance(optimize, paths.PathOptimizer):
# Custom path optimizer supplied
path_tuple = path(input_sets, output_set, size_dict, memory_arg) # type: ignore
path_tuple = optimize(input_sets, output_set, size_dict, memory_arg) # type: ignore
else:
path_optimizer = paths.get_path_fn(optimize)
path_tuple = path_optimizer(input_sets, output_set, size_dict, memory_arg)
Expand Down Expand Up @@ -419,7 +409,7 @@ def contract_path(


@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
def _einsum(*operands: Any, **kwargs: Any) -> ArrayType:
"""Base einsum, but with pre-parse for valid characters if a string is given."""
fn = backends.get_func("einsum", kwargs.pop("backend", "numpy"))

Expand Down Expand Up @@ -707,8 +697,11 @@ def _core_contract(

else:
# Call einsum
out_kwarg: None | ArrayType = None
if handle_out:
out_kwarg = out
new_view = _einsum(
einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out
einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out_kwarg
)

# Append new items and dereference what we can
Expand Down Expand Up @@ -910,9 +903,10 @@ def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType:
if backends.has_backend(backend) and all(infer_backend(x) == "numpy" for x in arrays):
return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
return self._contract(ops, out=out, backend=backend, evaluate_constants=evaluate_constants)

except ValueError as err:
raise
original_msg = str(err.args) if err.args else ""
msg = (
"Internal error while evaluating `ContractExpression`. Note that few checks are performed"
Expand Down

0 comments on commit 96c3b22

Please sign in to comment.