diff --git a/unyt/dask_array.py b/unyt/dask_array.py
index ef51960d..6d4cc0e3 100644
--- a/unyt/dask_array.py
+++ b/unyt/dask_array.py
@@ -650,10 +650,11 @@ def reduce_with_units(dask_func, unyt_dask_in, *args, **kwargs):
 
     Examples
     --------
-    >>> from unyt import dask_array
-    >>> a = dask_array.dask.array.ones((10000,), chunks=(100,))
-    >>> a = dask_array.unyt_from_dask(a, 'm')
-    >>> b = dask_array.reduce_with_units(dask_array.dask.array.median, a, axis=0)
+    >>> import dask.array
+    >>> from unyt.dask_array import unyt_from_dask, reduce_with_units
+    >>> a = dask.array.ones((10000,), chunks=(100,))
+    >>> a = unyt_from_dask(a, 'm')
+    >>> b = reduce_with_units(dask.array.median, a, axis=0)
     >>> b.compute()
     unyt_quantity(1., 'm')