[MNT] update linting: limit line length to 88, add isort (#1740)
This PR updates the linting rules to make the code more readable and
consistent with `sktime`:

* limit line length to 88 and apply throughout the code base
* add `isort` formatting and apply throughout the code base
* reorder `pyproject.toml` to ensure linting settings are at the end and
project information is at the top
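
To make the first two bullets concrete, here is a small constructed sketch (not a line taken from this diff) of the style the new rules enforce: imports grouped and sorted isort-style with `pytorch_forecasting` as first-party, and call arguments wrapped so no physical line exceeds 88 characters.

```python
# Constructed illustration of the style the new rules enforce; not code from this commit.
# isort section order: standard library, then third party, then first party
# ("pytorch_forecasting" is declared known-first-party in pyproject.toml).
import warnings

import numpy as np

from pytorch_forecasting import (
    GroupNormalizer,
    TemporalFusionTransformer,
    TimeSeriesDataSet,
)

# warnings and numpy are left unused on purpose; F401 (unused imports) remains in the
# ignore list, so ruff does not flag them.
# Calls that would exceed 88 characters on one line are wrapped onto parenthesized
# continuation lines, as in the example scripts changed below.
early_stop_kwargs = dict(
    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
```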
fkiraly authored Dec 26, 2024
1 parent 7a26a58 commit 533fd03
Showing 51 changed files with 3,206 additions and 952 deletions.
16 changes: 12 additions & 4 deletions build_tools/changelog.py
@@ -62,7 +62,9 @@ def fetch_latest_release(): # noqa: D103
     """
     import httpx

-    response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS)
+    response = httpx.get(
+        f"{GITHUB_REPOS}/{OWNER}/{REPO}/releases/latest", headers=HEADERS
+    )

     if response.status_code == 200:
         return response.json()
@@ -91,7 +93,9 @@ def fetch_pull_requests_since_last_release() -> list[dict]:
     all_pulls = []
     while not is_exhausted:
         pulls = fetch_merged_pull_requests(page=page)
-        all_pulls.extend([p for p in pulls if parser.parse(p["merged_at"]) > published_at])
+        all_pulls.extend(
+            [p for p in pulls if parser.parse(p["merged_at"]) > published_at]
+        )
         is_exhausted = any(parser.parse(p["updated_at"]) < published_at for p in pulls)
         page += 1
     return all_pulls
@@ -101,7 +105,9 @@ def github_compare_tags(tag_left: str, tag_right: str = "HEAD"):
     """Compare commit between two tags."""
     import httpx

-    response = httpx.get(f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}")
+    response = httpx.get(
+        f"{GITHUB_REPOS}/{OWNER}/{REPO}/compare/{tag_left}...{tag_right}"
+    )
     if response.status_code == 200:
         return response.json()
     else:
@@ -135,7 +141,9 @@ def assign_prs(prs, categs: list[dict[str, list[str]]]):
     # if any(l.startswith("module") for l in pr_labels):
     #     print(i, pr_labels)

-    assigned["Other"] = list(set(range(len(prs))) - {i for _, j in assigned.items() for i in j})
+    assigned["Other"] = list(
+        set(range(len(prs))) - {i for _, j in assigned.items() for i in j}
+    )

     return assigned
4 changes: 3 additions & 1 deletion docs/source/conf.py
@@ -145,7 +145,9 @@ def setup(app: Sphinx):
     "navbar_end": ["navbar-icon-links.html", "search-field.html"],
     "show_nav_level": 2,
     "header_links_before_dropdown": 10,
-    "external_links": [{"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}],
+    "external_links": [
+        {"name": "GitHub", "url": "https://github.com/sktime/pytorch-forecasting"}
+    ],
 }

 html_sidebars = {
12 changes: 9 additions & 3 deletions examples/ar.py
@@ -51,14 +51,20 @@
     stop_randomization=True,
 )
 batch_size = 64
-train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
-val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
+train_dataloader = training.to_dataloader(
+    train=True, batch_size=batch_size, num_workers=0
+)
+val_dataloader = validation.to_dataloader(
+    train=False, batch_size=batch_size, num_workers=0
+)

 # save datasets
 training.save("training.pkl")
 validation.save("validation.pkl")

-early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min")
+early_stop_callback = EarlyStopping(
+    monitor="val_loss", min_delta=1e-4, patience=5, verbose=False, mode="min"
+)
 lr_logger = LearningRateMonitor()

 trainer = pl.Trainer(
23 changes: 18 additions & 5 deletions examples/nbeats.py
@@ -42,13 +42,21 @@
     add_target_scales=False,
 )

-validation = TimeSeriesDataSet.from_dataset(training, data, min_prediction_idx=training_cutoff)
+validation = TimeSeriesDataSet.from_dataset(
+    training, data, min_prediction_idx=training_cutoff
+)
 batch_size = 128
-train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
-val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
+train_dataloader = training.to_dataloader(
+    train=True, batch_size=batch_size, num_workers=2
+)
+val_dataloader = validation.to_dataloader(
+    train=False, batch_size=batch_size, num_workers=2
+)


-early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
+early_stop_callback = EarlyStopping(
+    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
+)
 trainer = pl.Trainer(
     max_epochs=100,
     accelerator="auto",
@@ -63,7 +71,12 @@


 net = NBeats.from_dataset(
-    training, learning_rate=3e-2, log_interval=10, log_val_interval=1, log_gradient_flow=False, weight_decay=1e-2
+    training,
+    learning_rate=3e-2,
+    log_interval=10,
+    log_val_interval=1,
+    log_gradient_flow=False,
+    weight_decay=1e-2,
 )
 print(f"Number of parameters in network: {net.size() / 1e3:.1f}k")

45 changes: 34 additions & 11 deletions examples/stallion.py
@@ -7,10 +7,16 @@
 import numpy as np
 from pandas.core.common import SettingWithCopyWarning

-from pytorch_forecasting import GroupNormalizer, TemporalFusionTransformer, TimeSeriesDataSet
+from pytorch_forecasting import (
+    GroupNormalizer,
+    TemporalFusionTransformer,
+    TimeSeriesDataSet,
+)
 from pytorch_forecasting.data.examples import get_stallion_data
 from pytorch_forecasting.metrics import QuantileLoss
-from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
+from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
+    optimize_hyperparameters,
+)

 warnings.simplefilter("error", category=SettingWithCopyWarning)

@@ -22,8 +28,12 @@

 data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
 data["time_idx"] -= data["time_idx"].min()
-data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
-data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean")
+data["avg_volume_by_sku"] = data.groupby(
+    ["time_idx", "sku"], observed=True
+).volume.transform("mean")
+data["avg_volume_by_agency"] = data.groupby(
+    ["time_idx", "agency"], observed=True
+).volume.transform("mean")
 # data = data[lambda x: (x.sku == data.iloc[0]["sku"]) & (x.agency == data.iloc[0]["agency"])]
 special_days = [
     "easter_day",
@@ -39,7 +49,9 @@
     "beer_capital",
     "music_fest",
 ]
-data[special_days] = data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
+data[special_days] = (
+    data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
+)

 training_cutoff = data["time_idx"].max() - 6
 max_encoder_length = 36
@@ -50,14 +62,17 @@
     time_idx="time_idx",
     target="volume",
     group_ids=["agency", "sku"],
-    min_encoder_length=max_encoder_length // 2,  # allow encoder lengths from 0 to max_prediction_length
+    min_encoder_length=max_encoder_length
+    // 2,  # allow encoder lengths from 0 to max_prediction_length
     max_encoder_length=max_encoder_length,
     min_prediction_length=1,
     max_prediction_length=max_prediction_length,
     static_categoricals=["agency", "sku"],
     static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
     time_varying_known_categoricals=["special_days", "month"],
-    variable_groups={"special_days": special_days},  # group of categorical variables can be treated as one variable
+    variable_groups={
+        "special_days": special_days
+    },  # group of categorical variables can be treated as one variable
     time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"],
     time_varying_unknown_categoricals=[],
     time_varying_unknown_reals=[
@@ -78,17 +93,25 @@
 )


-validation = TimeSeriesDataSet.from_dataset(training, data, predict=True, stop_randomization=True)
+validation = TimeSeriesDataSet.from_dataset(
+    training, data, predict=True, stop_randomization=True
+)
 batch_size = 64
-train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
-val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=0)
+train_dataloader = training.to_dataloader(
+    train=True, batch_size=batch_size, num_workers=0
+)
+val_dataloader = validation.to_dataloader(
+    train=False, batch_size=batch_size, num_workers=0
+)


 # save datasets
 training.save("training.pkl")
 validation.save("validation.pkl")

-early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
+early_stop_callback = EarlyStopping(
+    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
+)
 lr_logger = LearningRateMonitor()
 logger = TensorBoardLogger(log_graph=True)

124 changes: 64 additions & 60 deletions pyproject.toml
@@ -1,63 +1,3 @@
-[tool.ruff]
-line-length = 120
-exclude = [
-    "docs/build/",
-    "node_modules/",
-    ".eggs/",
-    "versioneer.py",
-    "venv/",
-    ".venv/",
-    ".git/",
-    ".history/",
-]
-
-[tool.ruff.lint]
-select = ["E", "F", "W", "C4", "S"]
-extend-ignore = [
-    "E203", # space before : (needed for how black formats slicing)
-    "E402", # module level import not at top of file
-    "E731", # do not assign a lambda expression, use a def
-    "E741", # ignore not easy to read variables like i l I etc.
-    "C406", # Unnecessary list literal - rewrite as a dict literal.
-    "C408", # Unnecessary dict call - rewrite as a literal.
-    "C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal.
-    "F401", # unused imports
-    "S101", # use of assert
-]
-
-[tool.ruff.lint.isort]
-known-first-party = ["pytorch_forecasting"]
-combine-as-imports = true
-force-sort-within-sections = true
-
-[tool.black]
-line-length = 120
-include = '\.pyi?$'
-exclude = '''
-(
-    /(
-        \.eggs # exclude a few common directories in the
-        | \.git # root of the project
-        | \.hg
-        | \.mypy_cache
-        | \.tox
-        | \.venv
-        | _build
-        | buck-out
-        | build
-        | dist
-    )/
-    | docs/build/
-    | node_modules/
-    | venve/
-    | .venv/
-)
-'''
-
-[tool.nbqa.mutate]
-ruff = 1
-black = 1
-
 [project]
 name = "pytorch-forecasting"
 readme = "README.md" # Markdown files are supported
@@ -184,3 +124,67 @@ build-backend = "setuptools.build_meta"
 requires = [
     "setuptools>=70.0.0",
 ]
+
+[tool.ruff]
+line-length = 88
+exclude = [
+    "docs/build/",
+    "node_modules/",
+    ".eggs/",
+    "versioneer.py",
+    "venv/",
+    ".venv/",
+    ".git/",
+    ".history/",
+]
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "C4", "S"]
+extend-select = [
+    "I", # isort
+    "C4", # https://pypi.org/project/flake8-comprehensions
+]
+extend-ignore = [
+    "E203", # space before : (needed for how black formats slicing)
+    "E402", # module level import not at top of file
+    "E731", # do not assign a lambda expression, use a def
+    "E741", # ignore not easy to read variables like i l I etc.
+    "C406", # Unnecessary list literal - rewrite as a dict literal.
+    "C408", # Unnecessary dict call - rewrite as a literal.
+    "C409", # Unnecessary list passed to tuple() - rewrite as a tuple literal.
+    "F401", # unused imports
+    "S101", # use of assert
+]
+
+[tool.ruff.lint.isort]
+known-first-party = ["pytorch_forecasting"]
+combine-as-imports = true
+force-sort-within-sections = true
+
+[tool.black]
+line-length = 88
+include = '\.pyi?$'
+exclude = '''
+(
+    /(
+        \.eggs # exclude a few common directories in the
+        | \.git # root of the project
+        | \.hg
+        | \.mypy_cache
+        | \.tox
+        | \.venv
+        | _build
+        | buck-out
+        | build
+        | dist
+    )/
+    | docs/build/
+    | node_modules/
+    | venve/
+    | .venv/
+)
+'''
+
+[tool.nbqa.mutate]
+ruff = 1
+black = 1
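
As a reading aid (an editorial gloss, not part of the commit): `extend-select = ["I"]` turns on ruff's isort-compatible import-sorting rules; `known-first-party` makes `pytorch_forecasting` imports sort into their own first-party section; `combine-as-imports` merges multiple `as` imports from the same module onto one line; and `force-sort-within-sections` sorts plain `import x` and `from x import y` statements together by module name instead of listing plain imports first. A minimal before/after sketch under that reading, using a hypothetical module:

```python
# Hypothetical module, before linting: ruff's "I" rules would reorder these imports
# under the configuration above.
from pytorch_forecasting import TimeSeriesDataSet
import numpy as np
from collections import defaultdict
import warnings

# After `ruff check --fix` (a sketch of the expected result): one block per section
# (standard library, third party, first party), and, because of
# force-sort-within-sections, statements within a section are ordered by module name
# regardless of `import` vs `from ... import` style.
from collections import defaultdict
import warnings

import numpy as np

from pytorch_forecasting import TimeSeriesDataSet
```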