
Fix lint
stellaraccident committed Feb 5, 2024
1 parent 9b1ac7c commit 4692c47
Showing 4 changed files with 14 additions and 50 deletions.
.github/workflows/test.yml: 29 changes (0 additions & 29 deletions)
@@ -47,32 +47,3 @@ jobs:
       - name: Run tests
         run: |
           pytest -n 4 core/tests/
-  black:
-    strategy:
-      matrix:
-        version: [3.11]
-        os: [ubuntu-latest]
-    runs-on: ${{matrix.os}}
-    steps:
-      - name: Checking out repository
-        uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0
-      - name: Setting up python
-        uses: actions/setup-python@d27e3f3d7c64b4bbf8e4abfb9b63b83e846e0435 # v4.5.0
-      - name: Fetching Base Branch
-        # We have to explicitly fetch the base branch as well
-        run: git fetch --no-tags --prune --depth=1 origin "${GITHUB_BASE_REF?}:${GITHUB_BASE_REF?}"
-      - name: Install black
-        run: |
-          python3 -m pip install black==23.3
-      - name: Check if modified files are formatted
-        run: |
-          # The filter lowercase `d` means to exclude deleted files.
-          git diff "${GITHUB_BASE_REF?}" --name-only --diff-filter=d \
-            -- '*.py' \
-            | xargs --no-run-if-empty black --check --diff --verbose
-      - name: Instructions for fixing the above linting errors
-        if: failure()
-        run: |
-          printf "You can fix formatting by running 'black' on the modified python files:\n"
-          printf "  git diff ${GITHUB_BASE_REF?} --name-only -- '*.py' ':!third_party' | xargs black\n"
core/examples/eager_mlp/mlp_eager_simple.py: 9 changes (7 additions & 2 deletions)
@@ -13,7 +13,9 @@
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
-torch._dynamo.config.dynamic_shapes = False  # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+torch._dynamo.config.dynamic_shapes = (
+    False  # TODO: https://github.com/nod-ai/SHARK-Turbine/issues/93
+)
 
 
 class MNISTDataLoader:
@@ -39,7 +41,10 @@ def get_train_loader(self):
 
     def get_test_loader(self):
         return DataLoader(
-            dataset=self.mnist_testset, batch_size=self.batch_size, shuffle=False, drop_last=True,
+            dataset=self.mnist_testset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            drop_last=True,
         )
 
 
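Note: both hunks above are formatting-only; the same keyword arguments reach DataLoader, just one per line. For readers unfamiliar with the example, a self-contained sketch of an equivalent test-set loader follows (the ./data root and batch size of 64 are hypothetical stand-ins for values the example keeps on its MNISTDataLoader instance):

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Equivalent construction to the reformatted call: identical arguments,
# written one keyword per line.
mnist_testset = datasets.MNIST(
    root="./data", train=False, download=True, transform=transforms.ToTensor()
)
test_loader = DataLoader(
    dataset=mnist_testset,
    batch_size=64,  # hypothetical; the example uses self.batch_size
    shuffle=False,
    drop_last=True,
)
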
core/shark_turbine/aot/support/procedural/globals.py: 10 changes (2 additions & 8 deletions)
@@ -141,10 +141,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
                 continue
             elif isinstance(value, AbstractTensor):
                 global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
                     attrs=self._attrs,
@@ -163,10 +160,7 @@ def track(self, module_builder: ModuleBuilder, export_namespace: str) -> Any:
                 continue
             elif isinstance(value, AbstractScalar):
                 global_type = value.get_ir_type(module_builder)
-                (
-                    actual_symbol_name,
-                    global_op,
-                ) = module_builder.create_typed_global(
+                (actual_symbol_name, global_op,) = module_builder.create_typed_global(
                     f"_{fq_name}",
                     global_type,
                     attrs=self._attrs,
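Note: both hunks in this file collapse a multi-line tuple unpacking onto one line; the trailing comma inside the parentheses on the left-hand side has no semantic effect. A tiny illustration with stand-in names (not the real ModuleBuilder API):

def create_typed_global(name, global_type):
    # Stand-in for module_builder.create_typed_global: returns (symbol_name, op).
    return f"sym_{name}", f"global<{global_type}>"

# Multi-line and single-line unpacking, with or without the trailing comma,
# bind exactly the same two values.
(
    symbol_a,
    op_a,
) = create_typed_global("_x", "tensor<4xf32>")
(symbol_b, op_b,) = create_typed_global("_x", "tensor<4xf32>")
assert (symbol_a, op_a) == (symbol_b, op_b)
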
core/shark_turbine/transforms/quantization/mm_group_quant.py: 16 changes (5 additions & 11 deletions)
@@ -63,9 +63,7 @@ def match(self, op: Operation):
             m=m,
             n=n,
             k=k,
-            element_type=self.builder.get_tensor_element_type(
-                op.operands[0].type
-            ),
+            element_type=self.builder.get_tensor_element_type(op.operands[0].type),
         )
 
 
@@ -134,9 +132,7 @@ def __init__(self, root_op: Operation, *, group_size: int = 128):
 
     def run(self):
         globals = self.globals
-        mms = match_children(
-            self.funcs, TransposedMMMatcher(globals, self.builder)
-        )
+        mms = match_children(self.funcs, TransposedMMMatcher(globals, self.builder))
 
         for mr in mms:
             if mr.k is None or mr.n is None:
@@ -165,12 +161,10 @@ def rewrite(self, mr: TransposedMMResult):
             element_type=mr.element_type,
         )
 
-        inline_module = Operation.parse(
-            inline_module_asm, context=self.context
+        inline_module = Operation.parse(inline_module_asm, context=self.context)
+        actual_callee_name = self.merge_module(inline_module).translate_symbol(
+            "compute_mm_group_quant"
         )
-        actual_callee_name = self.merge_module(
-            inline_module
-        ).translate_symbol("compute_mm_group_quant")
         with InsertionPoint(mr.op), mr.op.location:
             results = self.builder.call_native(
                 actual_callee_name, [mr.op.result.type], mr.op.operands[0]
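Note: each hunk in this file re-wraps a call expression without touching its arguments. A quick way to convince yourself such a change is purely cosmetic is to compare the ASTs of the two spellings (standard library only; the snippet uses the Operation.parse lines from the last hunk):

import ast

old_src = (
    "inline_module = Operation.parse(\n"
    "    inline_module_asm, context=self.context\n"
    ")\n"
)
new_src = "inline_module = Operation.parse(inline_module_asm, context=self.context)\n"

# ast.dump omits line/column info by default, so equal dumps mean the
# two spellings parse to the same program.
assert ast.dump(ast.parse(old_src)) == ast.dump(ast.parse(new_src))
print("formatting-only change")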
