Skip to content

Commit

Permalink
#89 Change task & framework parameters to optional in compressor (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
Only-bottle authored Jan 22, 2024
1 parent 10c80d1 commit 4e661c4
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 36 deletions.
14 changes: 5 additions & 9 deletions examples/compressor/automatic_compression.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from netspresso.compressor import Compressor, Task, Framework
from netspresso.compressor import Compressor


EMAIL = "YOUR_EMAIL"
PASSWORD = "YOUR_PASSWORD"
compressor = Compressor(email=EMAIL, password=PASSWORD)

MODEL_NAME = "test_h5"
TASK = Task.IMAGE_CLASSIFICATION
FRAMEWORK = Framework.TENSORFLOW_KERAS
INPUT_SHAPES = [{"batch": 1, "channel": 3, "dimension": [32, 32]}]
INPUT_MODEL_PATH = "./examples/sample_models/mobilenetv1.h5"
OUTPUT_MODEL_PATH = "./outputs/compressed/mobilenetv1_cifar100_automatic"
MODEL_NAME = "test_graphmodule_pt"
INPUT_SHAPES = [{"batch": 1, "channel": 3, "dimension": [224, 224]}]
INPUT_MODEL_PATH = "./examples/sample_models/graphmodule.pt"
OUTPUT_MODEL_PATH = "./outputs/compressed/graphmodule_automatic_compression"
COMPRESSION_RATIO = 0.5

compressed_model = compressor.automatic_compression(
model_name=MODEL_NAME,
task=TASK,
framework=FRAMEWORK,
input_shapes=INPUT_SHAPES,
input_path=INPUT_MODEL_PATH,
output_path=OUTPUT_MODEL_PATH,
Expand Down
6 changes: 0 additions & 6 deletions examples/compressor/manual_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

from netspresso.compressor import (
Compressor,
Task,
Framework,
CompressionMethod,
Policy,
LayerNorm,
Expand All @@ -18,14 +16,10 @@

# Upload Model
UPLOAD_MODEL_NAME = "test_pt"
TASK = Task.IMAGE_CLASSIFICATION
FRAMEWORK = Framework.PYTORCH
UPLOAD_MODEL_PATH = "./examples/sample_models/graphmodule.pt"
INPUT_SHAPES = [{"batch": 1, "channel": 3, "dimension": [224, 224]}]
model = compressor.upload_model(
model_name=UPLOAD_MODEL_NAME,
task=TASK,
framework=FRAMEWORK,
file_path=UPLOAD_MODEL_PATH,
input_shapes=INPUT_SHAPES,
)
Expand Down
6 changes: 0 additions & 6 deletions examples/compressor/recommendation_compression.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from netspresso.compressor import (
Compressor,
Task,
Framework,
CompressionMethod,
RecommendationMethod,
)
Expand All @@ -13,8 +11,6 @@

# Upload Model
MODEL_NAME = "test_pt"
TASK = Task.IMAGE_CLASSIFICATION
FRAMEWORK = Framework.PYTORCH
INPUT_MODEL_PATH = "./examples/sample_models/graphmodule.pt"
OUTPUT_MODEL_PATH = "./outputs/compressed/graphmodule_recommend"
INPUT_SHAPES = [{"batch": 1, "channel": 3, "dimension": [224, 224]}]
Expand All @@ -24,8 +20,6 @@

compressed_model = compressor.recommendation_compression(
model_name=MODEL_NAME,
task=TASK,
framework=FRAMEWORK,
compression_method=COMPRESSION_METHOD,
recommendation_method=RECOMMENDATION_METHOD,
recommendation_ratio=RECOMMENDATION_RATIO,
Expand Down
15 changes: 6 additions & 9 deletions netspresso/clients/compressor/schemas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,17 @@ def validate_request_params(cls, values):
"Invalid framework. Supported frameworks are TensorFlow/Keras, PyTorch, and ONNX."
)

if framework == Framework.TENSORFLOW_KERAS and not file_extension in [
Extension.H5,
Extension.ZIP,
]:
if file_extension in [Extension.H5, Extension.ZIP] and not framework == Framework.TENSORFLOW_KERAS:
raise Exception(
"Invalid file extension or framework. TensorFlow/Keras models should have .h5 or .zip extension."
"Invalid model framework. Models with .h5 or .zip extensions must use TensorFlow/Keras framework."
)
elif framework == Framework.PYTORCH and not file_extension == Extension.PT:
elif file_extension == Extension.PT and not framework == Framework.PYTORCH:
raise Exception(
"Invalid file extension or framework. PyTorch models should have .pt extension."
"Invalid model framework. Models with .pt extensions must use PyTorch framework."
)
elif framework == Framework.ONNX and not file_extension == Extension.ONNX:
elif file_extension == Extension.ONNX and not framework == Framework.ONNX:
raise Exception(
"Invalid file extension or framework. ONNX models should have .onnx extension."
"Invalid model framework. Models with .onnx extensions must use ONNX framework."
)

if framework == Framework.PYTORCH and input_layers is None:
Expand Down
12 changes: 6 additions & 6 deletions netspresso/compressor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def __init__(self, email=None, password=None, user_session=None):
def upload_model(
self,
model_name: str,
task: Task,
framework: Framework,
file_path: str,
input_shapes: List[Dict[str, int]] = [],
task: Task = Task.OTHER,
framework: Framework = Framework.PYTORCH,
) -> Model:
"""Upload a model for compression.
Expand Down Expand Up @@ -534,9 +534,9 @@ def recommendation_compression(
recommendation_ratio: float,
input_path: str,
output_path: str,
task: Task,
framework: Framework,
input_shapes: List[Dict[str, int]],
task: Task = Task.OTHER,
framework: Framework = Framework.PYTORCH,
options: Options = Options(),
dataset_path: str = None,
) -> CompressedModel:
Expand Down Expand Up @@ -690,11 +690,11 @@ def recommendation_compression(
def automatic_compression(
self,
model_name: str,
task: Task,
framework: Framework,
input_shapes: List[Dict[str, int]],
input_path: str,
output_path: str,
task: Task = Task.OTHER,
framework: Framework = Framework.PYTORCH,
compression_ratio: float = 0.5,
) -> CompressedModel:
"""Compress a model automatically based on the given compression ratio.
Expand Down

0 comments on commit 4e661c4

Please sign in to comment.