Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v3.4 Release notes and removes NumPy kwargs #239

Merged
merged 8 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion devtools/conda-envs/full-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
4 changes: 2 additions & 2 deletions devtools/conda-envs/min-deps-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
- ruff ==0.6.*
2 changes: 1 addition & 1 deletion devtools/conda-envs/min-ver-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
2 changes: 1 addition & 1 deletion devtools/conda-envs/torch-only-environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:

# Testing
- codecov
- mypy
- mypy ==1.11*
- pytest
- pytest-cov
- ruff ==0.5.*
29 changes: 29 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
Changelog
=========

## 3.4.0 / 2024-09-XX

NumPy has been removed from `opt_einsum` as a dependancy allowing for more flexible installs.
dgasmith marked this conversation as resolved.
Show resolved Hide resolved

**New Features**

- [\#160](https://github.com/dgasmith/opt_einsum/pull/160) Migrates docs to MkDocs Material and GitHub pages hosting.
- [\#161](https://github.com/dgasmith/opt_einsum/pull/161) Adds Python type annotations to the code base.
- [\#204](https://github.com/dgasmith/opt_einsum/pull/204) Removes NumPy as a hard dependancy.
dgasmith marked this conversation as resolved.
Show resolved Hide resolved

**Enhancements**

- [\#154](https://github.com/dgasmith/opt_einsum/pull/154) Prevents an infinite recursion error when the `memory_limit` was set very low for the `dp` algorithm.
- [\#155](https://github.com/dgasmith/opt_einsum/pull/155) Adds flake8 spell check to the doc strings
- [\#159](https://github.com/dgasmith/opt_einsum/pull/159) Migrates to GitHub actions for CI.
- [\#174](https://github.com/dgasmith/opt_einsum/pull/174) Prevents double contracts of floats in dynamic paths.
- [\#196](https://github.com/dgasmith/opt_einsum/pull/196) Allows `backend=None` which is equivalent to `backend='auto'`
- [\#208](https://github.com/dgasmith/opt_einsum/pull/208) Switches to `ConfigParser` insetad of `SafeConfigParser` for Python 3.12 compatability.
- [\#228](https://github.com/dgasmith/opt_einsum/pull/228) `backend='jaxlib'` is now an alias for the `jax` library
- [\#237](https://github.com/dgasmith/opt_einsum/pull/237) Switches to `ruff` for formatting and linting.
- [\#238](https://github.com/dgasmith/opt_einsum/pull/238) Removes `numpy`-specific keyword args from being explicitly defined in `contract` and uses `**kwargs` instead.

**Bug Fixes**

- [\#195](https://github.com/dgasmith/opt_einsum/pull/195) Fixes a bug where `dp` would not work for scalar-only contractions.
- [\#200](https://github.com/dgasmith/opt_einsum/pull/200) Fixes a bug where `parse_einsum_input` would not correctly respect shape-only contractions.
- [\#222](https://github.com/dgasmith/opt_einsum/pull/222) Fixes an erorr in `parse_einsum_input` where an output subscript specified multiple times was not correctly caught.
- [\#229](https://github.com/dgasmith/opt_einsum/pull/229) Fixes a bug where empty contraction lists in `PathInfo` would cause an error.

## 3.3.0 / 2020-07-19

Adds a `object` backend for optimized contractions on arbitrary Python objects.
Expand Down
66 changes: 19 additions & 47 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@

## Common types

_OrderKACF = Literal[None, "K", "A", "C", "F"]

_Casting = Literal["no", "equiv", "safe", "same_kind", "unsafe"]
_MemoryLimit = Union[None, int, Decimal, Literal["max_input"]]


Expand Down Expand Up @@ -284,7 +281,7 @@ def contract_path(
#> 5 defg,hd->efgh efgh->efgh
```
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

# Hidden option, only einsum should call this
Expand Down Expand Up @@ -344,9 +341,11 @@ def contract_path(
naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

# Compute the path
if not isinstance(optimize, (str, paths.PathOptimizer)):
if optimize is False:
path_tuple: PathType = [tuple(range(num_ops))]
elif not isinstance(optimize, (str, paths.PathOptimizer)):
# Custom path supplied
path_tuple: PathType = optimize # type: ignore
path_tuple = optimize # type: ignore
elif num_ops <= 2:
# Nothing to be optimized
path_tuple = [tuple(range(num_ops))]
Expand Down Expand Up @@ -479,9 +478,6 @@ def contract(
subscripts: str,
*operands: ArrayType,
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -495,9 +491,6 @@ def contract(
subscripts: ArrayType,
*operands: Union[ArrayType, Collection[int]],
out: ArrayType = ...,
dtype: Any = ...,
order: _OrderKACF = ...,
casting: _Casting = ...,
use_blas: bool = ...,
optimize: OptimizeKind = ...,
memory_limit: _MemoryLimit = ...,
Expand All @@ -510,9 +503,6 @@ def contract(
subscripts: Union[str, ArrayType],
*operands: Union[ArrayType, Collection[int]],
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
use_blas: bool = True,
optimize: OptimizeKind = True,
memory_limit: _MemoryLimit = None,
Expand All @@ -527,9 +517,6 @@ def contract(
subscripts: Specifies the subscripts for summation.
*operands: These are the arrays for the operation.
out: A output array in which set the resulting output.
dtype: The dtype of the given contraction, see np.einsum.
order: The order of the resulting contraction, see np.einsum.
casting: The casting procedure for operations of different dtype, see np.einsum.
use_blas: Do you use BLAS for valid operations, may use extra memory for more intermediates.
optimize:- Choose the type of path the contraction will be optimized with
- if a list is given uses this as the path.
Expand All @@ -551,11 +538,12 @@ def contract(
- `'branch-2'` An even more restricted version of 'branch-all' that
only searches the best two options at each step. Scales exponentially
with the number of terms in the contraction.
- `'auto'` Choose the best of the above algorithms whilst aiming to
- `'auto', None, True` Choose the best of the above algorithms whilst aiming to
keep the path finding time below 1ms.
- `'auto-hq'` Aim for a high quality contraction, choosing the best
of the above algorithms whilst aiming to keep the path finding time
below 1sec.
- `False` will not optimize the contraction.

memory_limit:- Give the upper bound of the largest intermediate tensor contract will build.
- None or -1 means there is no limit.
Expand Down Expand Up @@ -586,21 +574,18 @@ def contract(
performed optimally. When NumPy is linked to a threaded BLAS, potential
speedups are on the order of 20-100 for a six core machine.
"""
if optimize is True:
if (optimize is True) or (optimize is None):
optimize = "auto"

operands_list = [subscripts] + list(operands)
einsum_kwargs = {"out": out, "dtype": dtype, "order": order, "casting": casting}

# If no optimization, run pure einsum
if optimize is False:
return _einsum(*operands_list, **einsum_kwargs)
return _einsum(*operands_list, out=out, **kwargs)

# Grab non-einsum kwargs
gen_expression = kwargs.pop("_gen_expression", False)
constants_dict = kwargs.pop("_constants_dict", {})
if len(kwargs):
raise TypeError(f"Did not understand the following kwargs: {kwargs.keys()}")

if gen_expression:
full_str = operands_list[0]
Expand All @@ -613,11 +598,9 @@ def contract(

# check if performing contraction or just building expression
if gen_expression:
return ContractExpression(full_str, contraction_list, constants_dict, dtype=dtype, order=order, casting=casting)
return ContractExpression(full_str, contraction_list, constants_dict, **kwargs)

return _core_contract(
operands, contraction_list, backend=backend, out=out, dtype=dtype, order=order, casting=casting
)
return _core_contract(operands, contraction_list, backend=backend, out=out, **kwargs)


@lru_cache(None)
Expand Down Expand Up @@ -651,9 +634,7 @@ def _core_contract(
backend: Optional[str] = "auto",
evaluate_constants: bool = False,
out: Optional[ArrayType] = None,
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
) -> ArrayType:
"""Inner loop used to perform an actual contraction given the output
from a ``contract_path(..., einsum_call=True)`` call.
Expand Down Expand Up @@ -703,7 +684,7 @@ def _core_contract(
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend, **kwargs)

# Build a new view if needed
if (tensor_result != results_index) or handle_out:
Expand All @@ -718,9 +699,7 @@ def _core_contract(
out_kwarg: Union[None, ArrayType] = None
if handle_out:
out_kwarg = out
new_view = _einsum(
einsum_str, *tmp_operands, backend=backend, dtype=dtype, order=order, casting=casting, out=out_kwarg
)
new_view = _einsum(einsum_str, *tmp_operands, backend=backend, out=out_kwarg, **kwargs)

# Append new items and dereference what we can
operands.append(new_view)
Expand Down Expand Up @@ -768,15 +747,11 @@ def __init__(
contraction: str,
contraction_list: ContractionListType,
constants_dict: Dict[int, ArrayType],
dtype: Optional[str] = None,
order: _OrderKACF = "K",
casting: _Casting = "safe",
**kwargs: Any,
):
self.contraction_list = contraction_list
self.dtype = dtype
self.order = order
self.casting = casting
self.contraction = format_const_einsum_str(contraction, constants_dict.keys())
self.contraction_list = contraction_list
self.kwargs = kwargs

# need to know _full_num_args to parse constants with, and num_args to call with
self._full_num_args = contraction.count(",") + 1
Expand Down Expand Up @@ -844,9 +819,7 @@ def _contract(
out=out,
backend=backend,
evaluate_constants=evaluate_constants,
dtype=self.dtype,
order=self.order,
casting=self.casting,
**self.kwargs,
)

def _contract_with_conversion(
Expand Down Expand Up @@ -943,8 +916,7 @@ def __str__(self) -> str:
for i, c in enumerate(self.contraction_list):
s.append(f"\n {i + 1}. ")
s.append(f"'{c[2]}'" + (f" [{c[-1]}]" if c[-1] else ""))
kwargs = {"dtype": self.dtype, "order": self.order, "casting": self.casting}
s.append(f"\neinsum_kwargs={kwargs}")
s.append(f"\neinsum_kwargs={self.kwargs}")
return "".join(s)


Expand Down
9 changes: 7 additions & 2 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,14 @@ def branch(
output: ArrayIndexType,
size_dict: Dict[str, int],
memory_limit: Optional[int] = None,
**optimizer_kwargs: Dict[str, Any],
nbranch: Optional[int] = None,
cutoff_flops_factor: int = 4,
minimize: str = "flops",
cost_fn: str = "memory-removed",
) -> PathType:
optimizer = BranchBound(**optimizer_kwargs) # type: ignore
optimizer = BranchBound(
nbranch=nbranch, cutoff_flops_factor=cutoff_flops_factor, minimize=minimize, cost_fn=cost_fn
)
return optimizer(inputs, output, size_dict, memory_limit)


Expand Down
13 changes: 13 additions & 0 deletions opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# NumPy is required for the majority of this file
np = pytest.importorskip("numpy")


tests = [
# Test scalar-like operations
"a,->a",
Expand Down Expand Up @@ -99,6 +100,18 @@
]


@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
expr = "ij,jk,kl->il"
ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]

path = contract_path(expr, *ops, optimize=optimize)
assert len(path) == 2

result = contract(expr, *ops, optimize=optimize)
assert result.shape == (2, 2)


@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
Expand Down
9 changes: 1 addition & 8 deletions opt_einsum/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,9 @@ def test_value_errors(contract_fn: Any) -> None:
# broadcasting to new dimensions must be enabled explicitly
with pytest.raises(ValueError):
contract_fn("i", np.arange(6).reshape(2, 3))
if contract_fn is contract:
# contract_path does not have an `out` parameter
with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], out=np.arange(4).reshape(2, 2))

with pytest.raises(TypeError):
contract_fn("i->i", [[0, 1], [0, 1]], bad_kwarg=True)

with pytest.raises(ValueError):
contract_fn("i->i", [[0, 1], [0, 1]], memory_limit=-1)
contract_fn("ij->ij", [[0, 1], [0, 1]], bad_kwarg=True)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_flop_cost() -> None:


def test_bad_path_option() -> None:
with pytest.raises(TypeError):
with pytest.raises(KeyError):
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore


Expand Down
Loading