Skip to content

Commit

Permalink
Avoid using match in python code
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631280405
  • Loading branch information
allight authored and copybara-github committed May 7, 2024
1 parent 5cb393f commit d94a1b0
Showing 1 changed file with 62 additions and 49 deletions.
111 changes: 62 additions & 49 deletions xls/jit/jit_wrapper_generator_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from collections.abc import Sequence
import dataclasses
import subprocess
from typing import Optional

from absl import app
from absl import flags
Expand Down Expand Up @@ -90,7 +91,7 @@ class XlsParam:
name: str
packed_type: str
unpacked_type: str
specialized_type: str | None
specialized_type: Optional[str]

@property
def value_arg(self):
Expand Down Expand Up @@ -135,41 +136,46 @@ def params_and_result(self):


def to_packed(t: type_pb2.TypeProto) -> str:
match t.type_enum:
case type_pb2.TypeProto.BITS:
return f"xls::PackedBitsView<{t.bit_count}>"
case type_pb2.TypeProto.TUPLE:
inner = ", ".join(to_packed(e) for e in t.tuple_elements)
return f"xls::PackedTupleView<{inner}>"
case type_pb2.TypeProto.ARRAY:
return (
f"xls::PackedArrayView<{to_packed(t.array_element)}, {t.array_size}>"
)
case _:
raise app.UsageError(
"Incompatible with argument of type:"
f" {type_pb2.TypeProto.TypeEnum.Name(t.type_enum)}"
)
the_type = t.type_enum
if the_type == type_pb2.TypeProto.BITS:
return f"xls::PackedBitsView<{t.bit_count}>"
elif the_type == type_pb2.TypeProto.TUPLE:
inner = ", ".join(to_packed(e) for e in t.tuple_elements)
return f"xls::PackedTupleView<{inner}>"
elif the_type == type_pb2.TypeProto.ARRAY:
return f"xls::PackedArrayView<{to_packed(t.array_element)}, {t.array_size}>"
raise app.UsageError(
"Incompatible with argument of type:"
f" {type_pb2.TypeProto.TypeEnum.Name(t.type_enum)}"
)


def to_unpacked(t: type_pb2.TypeProto, mutable: bool = False) -> str:
"""Get the unpacked c++ view type.
Args:
t: The xls type
mutable: Whether the type is mutable.
Returns:
the C++ unpacked view type
"""
mutable_str = "Mutable" if mutable else ""
match t.type_enum:
case type_pb2.TypeProto.BITS:
return f"xls::{mutable_str}BitsView<{t.bit_count}>"
case type_pb2.TypeProto.TUPLE:
inner = ", ".join(to_unpacked(e, mutable) for e in t.tuple_elements)
return f"xls::{mutable_str}TupleView<{inner}>"
case type_pb2.TypeProto.ARRAY:
return (
f"xls::{mutable_str}ArrayView<{to_unpacked(t.array_element, mutable)},"
f" {t.array_size}>"
)
case _:
raise app.UsageError(
"Incompatible with argument of type:"
f" {type_pb2.TypeProto.TypeEnum.Name(t.type_enum)}"
)
the_type = t.type_enum
if the_type == type_pb2.TypeProto.BITS:
return f"xls::{mutable_str}BitsView<{t.bit_count}>"
elif the_type == type_pb2.TypeProto.TUPLE:
inner = ", ".join(to_unpacked(e, mutable) for e in t.tuple_elements)
return f"xls::{mutable_str}TupleView<{inner}>"
elif the_type == type_pb2.TypeProto.ARRAY:
return (
f"xls::{mutable_str}ArrayView<{to_unpacked(t.array_element, mutable)},"
f" {t.array_size}>"
)
raise app.UsageError(
"Incompatible with argument of type:"
f" {type_pb2.TypeProto.TypeEnum.Name(t.type_enum)}"
)


def is_floating_point(
Expand All @@ -195,23 +201,30 @@ def is_float_tuple(t: type_pb2.TypeProto) -> bool:
return is_floating_point(t, 8, 23)


def to_specialized(t: type_pb2.TypeProto) -> str | None:
match t.type_enum:
case type_pb2.TypeProto.BITS:
if t.bit_count <= 8:
return "uint8_t"
elif t.bit_count <= 16:
return "uint16_t"
elif t.bit_count <= 32:
return "uint32_t"
elif t.bit_count <= 64:
return "uint64_t"
case type_pb2.TypeProto.TUPLE if is_double_tuple(t):
return "double"
case type_pb2.TypeProto.TUPLE if is_float_tuple(t):
return "float"
case _:
return None
def to_specialized(t: type_pb2.TypeProto) -> Optional[str]:
"""Get the specialized c++ type.
Args:
t: The xls type
Returns:
the C++ type
"""
the_type = t.type_enum
if the_type == type_pb2.TypeProto.BITS:
if t.bit_count <= 8:
return "uint8_t"
elif t.bit_count <= 16:
return "uint16_t"
elif t.bit_count <= 32:
return "uint32_t"
elif t.bit_count <= 64:
return "uint64_t"
elif is_double_tuple(t):
return "double"
elif is_float_tuple(t):
return "float"
return None


def to_param(p: ir_interface_pb2.PackageInterfaceProto.NamedValue) -> XlsParam:
Expand Down

0 comments on commit d94a1b0

Please sign in to comment.