Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed File Upload Issues in Copy Script #312

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions copy_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def extract_args(cli_args) -> tuple:
"""
Function extracts args from cli command.
"""
return cli_args.id, cli_args.source, cli_args.destination
return cli_args.id, cli_args.source, cli_args.destination, cli_args.model_config


# @TODO: Add STRATIFY = "Stratify" handling?
Expand All @@ -34,13 +34,22 @@ def extract_args(cli_args) -> tuple:
help="The destination URL to copy the project to including protocol and port "
"(if needed). E.g. http://localhost:8001",
)
parser.add_argument(
"-c",
"--model_config",
help="Boolean to determine if the model configurations should be copied.",
default="true",
)
args = parser.parse_args()
project_id, source_url, destination_url = extract_args(args)
project_id, source_url, destination_url, copy_configs = extract_args(args)

if project_id and source_url and destination_url:
print("Running copy project operation.")
copy_class = CopyProject(
pid=project_id, source=source_url, dest=destination_url
pid=project_id,
source=source_url,
dest=destination_url,
copy_configs=copy_configs == "true",
)
# Copy the project.
copy_class.copy_project_obj()
Expand Down
18 changes: 14 additions & 4 deletions dev/tools/copy_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ class CopyProject:
"SimulateCiemssOperation",
]

def __init__(self, pid: int, source: str, dest: str):
def __init__(self, pid: int, source: str, dest: str, copy_configs: bool):
self.project_id = pid
self.source_url = source
self.destination_url = dest
self._copy_model_configurations = copy_configs

def copy_project_obj(self):
"""
Function copies base project to destination.
"""
self._fetch_project()
new_project = scrub_obj(self.source_project)
new_project["assets"] = {}
post_url = self.post_url.format(host=self.destination_url, resource="projects")
response = post_to_destination(url=post_url, body=new_project)
new_project["id"] = response["id"]
Expand Down Expand Up @@ -183,6 +185,11 @@ def validate_copy(self):
)
failed_resources = []
for entity in self.source_project_assets.keys():
if self._copy_model_configurations is False and entity in [
"model_configurations",
"simulations",
]:
continue
if len(new_project["assets"][entity]) == len(
self.source_project_assets[entity]
):
Expand Down Expand Up @@ -275,7 +282,10 @@ def _process_workflow_models(self, models: list):
raise CopyProjectFailed(message=error_msg)
if len(model["outputs"]):
for model_output in model["outputs"]:
if model_output["type"] == "modelConfigId":
if (
model_output["type"] == "modelConfigId"
and self._copy_model_configurations
):
for val in model_output["value"]:
self._process_model_config(
model_config_id=val, new_model_id=new_model_id
Expand Down Expand Up @@ -311,7 +321,7 @@ def _process_simulation_node(self, simulation_node: dict):
if (
self.id_mapper["simulations"]
and simulation_node in self.id_mapper["simulations"].keys()
):
) or self._copy_model_configurations is False:
return
new_simulation_node = simulation_node
if len(new_simulation_node["inputs"]):
Expand Down Expand Up @@ -433,7 +443,7 @@ def _upload_artifact(

requests.put(
url=upload_url["url"],
files={f"{filename}": download_file.content},
data=download_file.content,
timeout=120,
)

Expand Down
10 changes: 10 additions & 0 deletions tds/lib/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,13 @@ def adjust_project_assets(
inactive_ids = set(resource_ids) - set(assets.get(resource_type, []))
for id in inactive_ids:
session.delete(session.query(ProjectAsset).get(id))


def clean_up_asset_return(project_assets):
return_obj = {}
for asset in project_assets:
if asset.resource_type not in return_obj:
return_obj[asset.resource_type] = []
return_obj[asset.resource_type].append(asset.resource_id)

return return_obj
17 changes: 11 additions & 6 deletions tds/modules/project/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.orm.exc import NoResultFound

from tds.db import entry_exists, es_client, request_rdb
from tds.db.enums import ResourceType
from tds.lib.projects import clean_up_asset_return
from tds.modules.project.helpers import (
ResourceDoesNotExist,
adjust_project_assets,
Expand Down Expand Up @@ -102,18 +103,22 @@ def project_get(project_id: int, rdb: Engine = Depends(request_rdb)) -> JSONResp
try:
if entry_exists(rdb.connect(), Project, project_id):
with Session(rdb) as session:
project = session.query(Project).get(project_id)
# pylint: disable-next=unused-variable
parameters: Query[ProjectAsset] = session.query(ProjectAsset).filter(
ProjectAsset.project_id == project_id
project = (
session.query(Project)
.options(joinedload(Project.assets))
.get(project_id)
)
project_response = {
**project.__dict__,
"assets": clean_up_asset_return(project.assets),
}
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

return JSONResponse(
status_code=status.HTTP_200_OK,
headers={"content-type": "application/json"},
content=jsonable_encoder(project),
content=jsonable_encoder(project_response),
)
except NoResultFound:
return JSONResponse(
Expand Down
Loading