Skip to content

Commit

Permalink
update links
Browse files Browse the repository at this point in the history
  • Loading branch information
han-cai committed Jul 19, 2023
1 parent 101fa22 commit cf2ff3a
Show file tree
Hide file tree
Showing 10 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
![](figures/diverse_hardware.png)

## OFA-ResNet50 [[How to use]](https://github.com/mit-han-lab/once-for-all/blob/master/tutorial/ofa_resnet50_example.ipynb)
<img src="https://hanlab.mit.edu/files/OnceForAll/figures/ofa_resnst50_results.png" width="60%" />
<img src="figures/ofa_resnst50_results.png" width="60%" />

## How to use / evaluate **OFA Networks**
### Use
Expand Down
Binary file added figures/ofa_resnst50_results.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 7 additions & 6 deletions ofa/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
]


def ofa_specialized(net_id, pretrained=True):
url_base = "https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/"
def ofa_specialized(net_id: str, pretrained=True):
url_base = "https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_specialized/"
net_id = net_id.replace("@", "-")
net_config = json.load(
open(
download_url(
Expand Down Expand Up @@ -88,12 +89,12 @@ def ofa_net(net_id, pretrained=True):
expand_ratio_list=[0.2, 0.25, 0.35],
width_mult_list=[0.65, 0.8, 1.0],
)
net_id = "ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0"
net_id = "ofa_resnet50_d0+1+2_e0.2+0.25+0.35_w0.65+0.8+1.0"
else:
raise ValueError("Not supported: %s" % net_id)

if pretrained:
url_base = "https://hanlab.mit.edu/files/OnceForAll/ofa_nets/"
url_base = "https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_nets/"
init = torch.load(
download_url(url_base + net_id, model_dir=".torch/ofa_nets"),
map_location="cpu",
Expand All @@ -104,13 +105,13 @@ def ofa_net(net_id, pretrained=True):

def proxylessnas_net(net_id, pretrained=True):
net = proxyless_base(
net_config="https://hanlab.mit.edu/files/proxylessNAS/%s.config" % net_id,
net_config="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/%s.config" % net_id,
)
if pretrained:
net.load_state_dict(
torch.load(
download_url(
"https://hanlab.mit.edu/files/proxylessNAS/%s.pth" % net_id
"https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/%s.pth" % net_id
),
map_location="cpu",
)["state_dict"]
Expand Down
2 changes: 1 addition & 1 deletion ofa/nas/efficiency_predictor/latency_lookup_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LatencyTable(object):
def __init__(
self,
local_dir="~/.ofa/latency_tools/",
url="https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml",
url="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/mobile_trim.yaml",
):
if url.startswith("http"):
fname = download_url(url, local_dir, overwrite=True)
Expand Down
2 changes: 1 addition & 1 deletion ofa/tutorial/accuracy_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, pretrained=True, device="cuda:0"):
if pretrained:
# load pretrained model
fname = download_url(
"https://hanlab.mit.edu/files/OnceForAll/tutorial/acc_predictor.pth"
"https://raw.githubusercontent.com/han-cai/files/master/ofa/acc_predictor.pth"
)
self.model.load_state_dict(
torch.load(fname, map_location=torch.device("cpu"))
Expand Down
5 changes: 2 additions & 3 deletions ofa/tutorial/latency_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class LatencyEstimator(object):
def __init__(
self,
local_dir="~/.hancai/latency_tools/",
url="https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml",
url="https://raw.githubusercontent.com/han-cai/files/master/proxylessnas/mobile_trim.yaml",
):
if url.startswith("http"):
fname = download_url(url, local_dir, overwrite=True)
Expand Down Expand Up @@ -198,8 +198,7 @@ def __init__(self, device="note10", resolutions=(160, 176, 192, 208, 224)):

for image_size in resolutions:
self.latency_tables[image_size] = LatencyEstimator(
url="https://hanlab.mit.edu/files/OnceForAll/tutorial/latency_table@%s/%d_lookup_table.yaml"
% (device, image_size)
url=f"https://raw.githubusercontent.com/han-cai/files/master/ofa/{device}/{image_size}_lookup_table.yaml"
)
print("Built latency table for image size: %d." % image_size)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
name="ofa",
version=VERSION,
author="MTI HAN LAB ",
author_email="hanlab.eecs+github@gmail.com",
author_email="hcai.hm@gmail.com",
url="https://github.com/mit-han-lab/once-for-all",
description="Once for All: Train One Network and Specialize it for Efficient Deployment.",
long_description=readme,
Expand Down
12 changes: 6 additions & 6 deletions train_ofa_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
torch.cuda.set_device(hvd.local_rank())

args.teacher_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D4_E6_K7",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D4_E6_K7",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)

Expand Down Expand Up @@ -255,7 +255,7 @@
validate_func_dict["ks_list"] = sorted(args.ks_list)
if distributed_run_manager.start_epoch == 0:
args.ofa_checkpoint_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D4_E6_K7",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D4_E6_K7",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)
load_models(
Expand Down Expand Up @@ -284,12 +284,12 @@

if args.phase == 1:
args.ofa_checkpoint_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D4_E6_K357",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D4_E6_K357",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)
else:
args.ofa_checkpoint_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D34_E6_K357",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D34_E6_K357",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)
train_elastic_depth(train, distributed_run_manager, args, validate_func_dict)
Expand All @@ -300,12 +300,12 @@

if args.phase == 1:
args.ofa_checkpoint_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D234_E6_K357",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D234_E6_K357",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)
else:
args.ofa_checkpoint_path = download_url(
"https://hanlab.mit.edu/files/OnceForAll/ofa_checkpoints/ofa_D234_E46_K357",
"https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_checkpoints/ofa_D234_E46_K357",
model_dir=".torch/ofa_checkpoints/%d" % hvd.rank(),
)
train_elastic_expand(train, distributed_run_manager, args, validate_func_dict)
Expand Down
8 changes: 4 additions & 4 deletions tutorial/ofa.ipynb

Large diffs are not rendered by default.

40 changes: 20 additions & 20 deletions tutorial/ofa_resnet50_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -56,7 +62,7 @@
")\n",
"\n",
"acc_predictor_checkpoint_path = download_url(\n",
" 'https://hanlab.mit.edu/files/OnceForAll/tutorial/ofa_resnet50_acc_predictor.pth',\n",
" 'https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_resnet50_acc_predictor.pth',\n",
" model_dir='~/.ofa/',\n",
")\n",
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
Expand All @@ -65,34 +71,34 @@
"\n",
"print('The accuracy predictor is ready!')\n",
"print(acc_predictor)"
],
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
},
"outputs": [],
"source": [
"# build efficiency predictor\n",
"from ofa.nas.efficiency_predictor import ResNet50FLOPsModel\n",
"\n",
"efficiency_predictor = ResNet50FLOPsModel(ofa_network)"
],
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -123,13 +129,7 @@
" predicted_efficiency = efficiency_predictor.get_efficiency(subnet_config)\n",
"\n",
" print(i, '\\t', predicted_acc, '\\t', '%.1fM MACs' % predicted_efficiency)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
]
}
],
"metadata": {
Expand Down

0 comments on commit cf2ff3a

Please sign in to comment.