[Pallas] Add pipeline mode to pltpu #25852

Open · wants to merge 1 commit into main
3 changes: 3 additions & 0 deletions jax/_src/pallas/core.py
@@ -373,6 +373,7 @@ class BlockSpec:
index_map: Callable[..., Any] | None = None
memory_space: Any | None = dataclasses.field(kw_only=True, default=None)
indexing_mode: IndexingMode = dataclasses.field(kw_only=True, default=blocked)
pipeline_mode: Any | None = None

def to_block_mapping(
self,
@@ -469,6 +470,7 @@ def to_block_mapping(
array_aval_shape, array_aval.dtype
),
origin=origin,
pipeline_mode=self.pipeline_mode,
)
mapping.check_invariants()
return mapping
@@ -510,6 +512,7 @@ class BlockMapping:
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: OriginStr
transforms: Sequence[MemoryRefTransform] = ()
pipeline_mode: Any | None = None

def check_invariants(self) -> None:
if not config.enable_checks.value: return
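For context, a minimal usage sketch of the new `pipeline_mode` field (a sketch assuming a TPU backend; it mirrors the usage exercised by the new test at the bottom of this diff):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def double(x_ref, o_ref):
  o_ref[:] = x_ref[:] * 2

x = jnp.arange(4096, dtype=jnp.float32)
y = pl.pallas_call(
    double,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    # Single-buffer this input's pipeline; the output spec keeps the
    # default (double-buffered) behavior.
    in_specs=[pl.BlockSpec((1024,), lambda i: i,
                           pipeline_mode=pltpu.PipelineMode.SYNCHRONOUS)],
    out_specs=pl.BlockSpec((1024,), lambda i: i),
    grid=4,
)(x)
```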
9 changes: 8 additions & 1 deletion jax/_src/pallas/mosaic/core.py
@@ -91,9 +91,16 @@ class TPUCompilerParams(pallas_core.CompilerParams):
internal_scratch_in_bytes: int | None = None
serialization_format: int = 1
device_type: str | None = None

replace = dataclasses.replace


class PipelineMode(enum.Enum):
SYNCHRONOUS = "synchronous"
DOUBLE_BUFFERED = "double_buffered"

def __str__(self) -> str:
return self.value

class TPUMemorySpace(enum.Enum):
ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY.
VMEM = "vmem"
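A quick check of the enum's string form, which is what the Mosaic lowering (next file) interpolates into the MLIR attribute:

```python
from jax.experimental.pallas import tpu as pltpu

# __str__ returns the enum's value, i.e. the MLIR-facing spelling.
assert str(pltpu.PipelineMode.SYNCHRONOUS) == "synchronous"
assert str(pltpu.PipelineMode.DOUBLE_BUFFERED) == "double_buffered"
```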
4 changes: 4 additions & 0 deletions jax/_src/pallas/mosaic/lowering.py
@@ -601,6 +601,10 @@ def lower_jaxpr_to_module(
block_params["window_kind"] = ir.Attribute.parse(
f"#tpu.element_window<{pad_low},{pad_high}>"
)
if bm.pipeline_mode is not None:
block_params["pipeline_mode"] = ir.Attribute.parse(
f"#tpu.pipeline_mode<{bm.pipeline_mode.value}>"
)
window_params.append(ir.DictAttr.get(block_params))
m.body.append(mlir_func)
sym_tab.insert(mlir_func)
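For reference, the attribute text this branch builds for `ir.Attribute.parse` (the `double_buffered` value here is illustrative):

```python
# Sketch of the string built for one block mapping's pipeline_mode.
mode = "double_buffered"  # bm.pipeline_mode.value
attr_text = f"#tpu.pipeline_mode<{mode}>"
# The `<` $value `>` shape matches the assemblyFormat declared for
# TPU_PipelineModeEnum in jaxlib/mosaic/dialect/tpu/tpu.td below.
```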
14 changes: 8 additions & 6 deletions jax/experimental/pallas/tpu.py
@@ -14,16 +14,19 @@

"""Mosaic-specific Pallas APIs."""

import types

from jax._src.pallas.mosaic import core as core
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
from jax._src.pallas.mosaic.core import create_tensorcore_mesh as create_tensorcore_mesh
from jax._src.pallas.mosaic.core import dma_semaphore as dma_semaphore
from jax._src.pallas.mosaic.core import PipelineMode as PipelineMode
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec as PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import semaphore as semaphore
from jax._src.pallas.mosaic.core import SemaphoreType as SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace
from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
from jax._src.pallas.mosaic.core import TPUMemorySpace as TPUMemorySpace
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef
@@ -49,12 +52,11 @@
from jax._src.pallas.mosaic.primitives import semaphore_signal as semaphore_signal
from jax._src.pallas.mosaic.primitives import semaphore_wait as semaphore_wait
from jax._src.pallas.mosaic.random import to_pallas_key as to_pallas_key

import types
from jax._src.pallas.mosaic.verification import assume
from jax._src.pallas.mosaic.verification import define_model
from jax._src.pallas.mosaic.verification import pretend
from jax._src.pallas.mosaic.verification import skip
from jax._src.pallas.mosaic.verification import define_model

verification = types.SimpleNamespace(
assume=assume, pretend=pretend, skip=skip, define_model=define_model
)
12 changes: 12 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -64,6 +64,18 @@ def TPU_CoreTypeEnum : EnumAttr<TPU_Dialect, TPU_CoreType, "core_type"> {
let assemblyFormat = "`<` $value `>`";
}

def TPU_PipelineMode : I32EnumAttr<"PipelineMode", "Pipeline mode", [
I32EnumAttrCase<"kSynchronous", 0, "synchronous">,
I32EnumAttrCase<"kDoubleBuffered", 1, "double_buffered">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";
}

def TPU_PipelineModeEnum : EnumAttr<TPU_Dialect, TPU_PipelineMode, "pipeline_mode"> {
let assemblyFormat = "`<` $value `>`";
}

def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>;
def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>;
def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>;
58 changes: 58 additions & 0 deletions tests/pallas/tpu_pallas_test.py
@@ -73,6 +73,64 @@ def setUp(self):
def pallas_call(self, *args, **kwargs):
return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)

class TPUPipelineModeTest(PallasBaseTest):

@parameterized.parameters(
(pltpu.PipelineMode.DOUBLE_BUFFERED, pltpu.PipelineMode.DOUBLE_BUFFERED),
(pltpu.PipelineMode.DOUBLE_BUFFERED, pltpu.PipelineMode.SYNCHRONOUS),
(pltpu.PipelineMode.SYNCHRONOUS, pltpu.PipelineMode.SYNCHRONOUS))
def test_two_input_vadd(self, x_pmode, y_pmode):
def body(x_ref, y_ref, o_ref):
x = x_ref[:]
y = y_ref[:]
o_ref[:] = x + y

size_in_vregs = 128
data_size = size_in_vregs * 1024
block_size = 1024

x = jnp.arange(data_size, dtype=jnp.float32)
y = jnp.arange(data_size, dtype=jnp.float32)
in_specs = [
pl.BlockSpec((block_size,), lambda i: i, pipeline_mode=pmode)
for pmode in [x_pmode, y_pmode]
]
out_specs = pl.BlockSpec((block_size,), lambda i: i)

@jax.jit
def vadd(x, y):
return self.pallas_call(
body,
out_shape=jax.ShapeDtypeStruct(x.shape, jnp.float32),
in_specs=in_specs,
out_specs=out_specs,
grid=data_size // block_size,
)(x, y)

compiled = (
vadd.lower(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(y.shape, y.dtype),
)
.compile()
.as_text()
)
pattern = (
r'"used_scoped_memory_configs":\[\{"memory_space":"1",.*?"size":"(\d+)"'
)
expected_vmem_usage = (
block_size
* 4
* (
2
+ (2 if x_pmode == pltpu.PipelineMode.DOUBLE_BUFFERED else 1)
+ (2 if y_pmode == pltpu.PipelineMode.DOUBLE_BUFFERED else 1)
)
)
vmem_usage = int(re.search(pattern, compiled).group(1))
self.assertEqual(vmem_usage, expected_vmem_usage)
z = vadd(x, y)
np.testing.assert_allclose(z, x + y)

class PallasCallScalarPrefetchTest(PallasBaseTest):
def test_trivial_scalar_prefetch(self):
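For concreteness, the VMEM bound the test's formula expects for each parameterization (float32, so 4 bytes per element; the default output spec contributes two buffers, and each input contributes two when double-buffered or one when synchronous):

```python
block_size = 1024  # elements per block; float32 => 4 bytes each
both_double = block_size * 4 * (2 + 2 + 2)  # 24576 bytes
mixed       = block_size * 4 * (2 + 2 + 1)  # 20480 bytes
both_sync   = block_size * 4 * (2 + 1 + 1)  # 16384 bytes
```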