Skip to content

Commit

Permalink
Fixed NAS test on Windows (#2264)
Browse files Browse the repository at this point in the history
### Changes

Use built-in `tmp_path` for temporary files to fix NAS tests on Windows

### Reason for changes

The PR (#2234) introduced a new test which fails on Windows with error:

`PermissionError: [Errno 13] Permission denied: 'C:\\Users\\SYS_K8~1\\AppData\\Local\\Temp\\tmpmf1i25nd'`


### Related tickets

124904

### Tests

NAS tests on Windows
  • Loading branch information
ljaljushkin authored Nov 14, 2023
1 parent 75a3403 commit ee435fd
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions tests/torch/nas/test_elastic_width.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import tempfile

import pytest
import torch
Expand Down Expand Up @@ -154,26 +153,26 @@ def test_width_reorg(basic_model):
compare_tensors_ignoring_the_order(after_reorg, before_reorg)


def test_width_custom_external_reorg(basic_model):
def test_width_custom_external_reorg(basic_model, tmp_path):
config = get_empty_config(input_sample_sizes=basic_model.INPUT_SIZE)
external_importance = basic_model.IMPORTANCE
with tempfile.NamedTemporaryFile() as external_importance_tempfile:
torch.save(external_importance, external_importance_tempfile)
config.update(
{
"bootstrapNAS": {
"training": {
"elasticity": {
"width": {
"filter_importance": "external",
"external_importance_path": external_importance_tempfile.name,
}
},
}
external_importance_tempfile = tmp_path / "importance_file"
torch.save(external_importance, external_importance_tempfile)
config.update(
{
"bootstrapNAS": {
"training": {
"elasticity": {
"width": {
"filter_importance": "external",
"external_importance_path": external_importance_tempfile,
}
},
}
}
)
model, ctrl = create_bootstrap_training_model_and_ctrl(basic_model, config)
}
)
model, ctrl = create_bootstrap_training_model_and_ctrl(basic_model, config)
model.eval()
device = next(model.parameters()).device
dummy_input = torch.Tensor([1]).reshape(basic_model.INPUT_SIZE).to(device)
Expand Down

0 comments on commit ee435fd

Please sign in to comment.