Skip to content

Commit

Permalink
dsl: support string type annotations containing '|'.
Browse files Browse the repository at this point in the history
  • Loading branch information
aszs committed Oct 27, 2024
1 parent ccce6b0 commit 0a88f27
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/examples/service_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MyApplication(tosca.nodes.SoftwareComponent):
private_address: str = Attribute()

host: Sequence[
tosca.relationships.HostedOn | tosca.nodes.Compute | tosca.capabilities.Compute
"tosca.relationships.HostedOn | tosca.nodes.Compute | tosca.capabilities.Compute"
] = ()

db: "DatabaseConnection"
Expand Down
6 changes: 5 additions & 1 deletion tosca-package/tosca/_tosca.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ def _get_type_name(_type):
def get_optional_type(_type) -> Tuple[bool, Any]:
# if not optional return false, type
# else return true, type or type
if isinstance(_type, ForwardRef):
_type = _type.__forward_arg__
if isinstance(_type, str):
union = [t.strip() for t in _type.split("|")]
try:
Expand Down Expand Up @@ -508,7 +510,9 @@ def pytype_to_tosca_type(_type, as_str=False) -> TypeInfo:
_type = Any
origin = get_origin(_type)

if _get_type_name(origin) in ["Union", "UnionType"]:
if isinstance(_type, ForwardRef):
types: tuple = tuple(ForwardRef(t.strip()) for t in _type.__forward_arg__.split("|"))
elif _get_type_name(origin) in ["Union", "UnionType"]:
types = get_args(_type)
else:
types = (_type,)
Expand Down
10 changes: 1 addition & 9 deletions tosca-package/tosca/yaml2python.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,16 +649,8 @@ def import_types(self, types: List[str]):
return self.maybe_forward_refs(*(self.python_name_from_type(t) for t in types))

def maybe_forward_refs(self, *types) -> Sequence[str]:
def may_quote(tn):
if self._builtin_prefix:
return repr(tn)
elif tn.startswith("unfurl.") or tn.startswith("tosca."):
return tn
else:
return repr(tn)

if self.forward_refs:
return [may_quote(t) for t in types]
return [repr(t) for t in types]
else:
return types

Expand Down

0 comments on commit 0a88f27

Please sign in to comment.