diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index d579c341..40ce6c2e 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -220,7 +220,7 @@ def export_prompt_encoder( input_batchsize = 1 if batch_input: - input_batchsize = batchsize + input_batchsize = batch_size if precision == "fp16": prompt_encoder_module = prompt_encoder_module.half()