diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 570a54198..8f51da931 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,18 +11,20 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] - pytorch-version: [1.7.1, 1.9.1, 1.10.1] + pytorch-version: [1.7.1, 1.9.1, 1.10.1, 2.0.1] include: - - python-version: 3.8 + - python-version: '3.8' pytorch-version: 1.7.1 torchvision-version: 0.8.2 - - python-version: 3.8 + - python-version: '3.8' pytorch-version: 1.9.1 torchvision-version: 0.10.1 - - python-version: 3.8 + - python-version: '3.8' pytorch-version: 1.10.1 torchvision-version: 0.11.2 + - python-version: '3.11' + pytorch-version: 2.0.1 + torchvision-version: 0.15.2 steps: - uses: conda-incubator/setup-miniconda@v2 - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision=${{ matrix.torchvision-version }} cpuonly -c pytorch diff --git a/clip/clip.py b/clip/clip.py index 257511e1d..f7a5da5e6 100644 --- a/clip/clip.py +++ b/clip/clip.py @@ -145,6 +145,14 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type. + + From https://github.com/pytorch/pytorch/pull/82628 + """ + sel = node.kindOf(key) + return getattr(node, sel)(key) + def patch_device(module): try: graphs = [module.graph] if hasattr(module, "graph") else [] @@ -156,7 +164,7 @@ def patch_device(module): for graph in graphs: for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): node.copyAttributes(device_node) model.apply(patch_device) @@ -182,7 +190,7 @@ def patch_float(module): for node in graph.findAllNodes("aten::to"): inputs = list(node.inputs()) for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: + if _node_get(inputs[i].node(), "value") == 5: inputs[i].node().copyAttributes(float_node) model.apply(patch_float)