Skip to content

Commit

Permalink
dsl: support string type annotations containing '|'.
Browse files Browse the repository at this point in the history
  • Loading branch information
aszs committed Oct 27, 2024
1 parent ccce6b0 commit d1853f7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/examples/service_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MyApplication(tosca.nodes.SoftwareComponent):
private_address: str = Attribute()

host: Sequence[
tosca.relationships.HostedOn | tosca.nodes.Compute | tosca.capabilities.Compute
"tosca.relationships.HostedOn | tosca.nodes.Compute | tosca.capabilities.Compute"
] = ()

db: "DatabaseConnection"
Expand Down
6 changes: 5 additions & 1 deletion tosca-package/tosca/_tosca.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,8 @@ def _get_type_name(_type):
def get_optional_type(_type) -> Tuple[bool, Any]:
# if not optional return false, type
# else return true, type or type
if isinstance(_type, ForwardRef):
_type = _type.__forward_arg__
if isinstance(_type, str):
union = [t.strip() for t in _type.split("|")]
try:
Expand Down Expand Up @@ -508,7 +510,9 @@ def pytype_to_tosca_type(_type, as_str=False) -> TypeInfo:
_type = Any
origin = get_origin(_type)

if _get_type_name(origin) in ["Union", "UnionType"]:
if isinstance(_type, ForwardRef):
types = (ForwardRef(t.strip()) for t in _type.__forward_arg__.split("|"))
elif _get_type_name(origin) in ["Union", "UnionType"]:
types = get_args(_type)
else:
types = (_type,)
Expand Down

0 comments on commit d1853f7

Please sign in to comment.