```python
from huggingface_hub import hf_hub_download
import torch
from transformers import AutoformerForPrediction

file = hf_hub_download(
    repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
)
batch = torch.load(file)

model = AutoformerForPrediction.from_pretrained("huggingface/autoformer-tourism-monthly")

# during training, one provides both past and future values
# as well as possible additional features
outputs = model(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_values=batch["future_values"],
    future_time_features=batch["future_time_features"],
)

loss = outputs.loss
loss.backward()

# during inference, one only provides past values
# as well as possible additional features
# the model autoregressively generates future values
outputs = model.generate(
    past_values=batch["past_values"],
    past_time_features=batch["past_time_features"],
    past_observed_mask=batch["past_observed_mask"],
    static_categorical_features=batch["static_categorical_features"],
    static_real_features=batch["static_real_features"],
    future_time_features=batch["future_time_features"],
)

mean_prediction = outputs.sequences.mean(dim=1)
```
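The demo passes several optional feature tensors from `batch` into layers whose sizes are fixed by the checkpoint's config, so one quick consistency check is to compare the last dimension of each tensor with the counts stored in the config. The helper below, `compare_feature_counts`, is hypothetical (it is not part of `transformers`); it only assumes the config exposes `num_time_features`, `num_static_real_features` and `num_static_categorical_features`, which `AutoformerConfig` defines. A mismatch here is one possible source of a shape error, not a confirmed diagnosis, and the config and tensors in the demo are made up.

```python
from types import SimpleNamespace

import torch


def compare_feature_counts(config, batch):
    """Hypothetical helper: for each feature tensor present in `batch`, return
    (size of its last dimension, count expected by config, whether they match)."""
    expected = {
        "past_time_features": config.num_time_features,
        "future_time_features": config.num_time_features,
        "static_real_features": config.num_static_real_features,
        "static_categorical_features": config.num_static_categorical_features,
    }
    report = {}
    for key, want in expected.items():
        if key in batch:
            got = batch[key].shape[-1]
            report[key] = (got, want, got == want)
    return report


# Demo with a made-up config and dummy tensors (not the real checkpoint values).
dummy_config = SimpleNamespace(
    num_time_features=2,
    num_static_real_features=1,
    num_static_categorical_features=1,
)
dummy_batch = {
    "past_time_features": torch.zeros(4, 61, 3),  # one column too many, on purpose
    "future_time_features": torch.zeros(4, 24, 2),
    "static_real_features": torch.zeros(4, 1),
    "static_categorical_features": torch.zeros(4, 1, dtype=torch.long),
}
for name, (got, want, ok) in compare_feature_counts(dummy_config, dummy_batch).items():
    print(f"{name}: batch has {got}, config expects {want}, match={ok}")
```

With the objects from the demo above, the equivalent call would be `compare_feature_counts(model.config, batch)`.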
Matrix dimension mismatch at `outputs = model(...)`: RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x23 and 22x64)
When running the Hugging Face demo code for AutoformerForPrediction, the call

`outputs = model(...)`

fails with a matrix shape mismatch: RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x23 and 22x64)
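In the corresponding dataset, the batch size is 64, the input length is 61, the prediction length is 24, and there are two time features. My experience is limited: I can only tell that 1536 = 64 × 24, and I cannot find the pattern behind the other dimensions. The earlier `AutoformerModel` demo is very similar, yet it does not raise an error at the `outputs = model(...)` step. How should I fix this? Many thanks!!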