Hold onto named tensors to ensure they don't get garbage collected in Python #1174

Merged · 2 commits · Jan 9, 2025
src/python/python.cpp: 46 changes (26 additions, 20 deletions)
@@ -215,6 +215,13 @@ struct PyDeviceMemorySpan {
   pybind11::array_t<T> py_cpu_array_;
 };
 
+struct PyNamedTensors {
+  PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
+  }
+
+  std::unique_ptr<NamedTensors> named_tensors_;
+};
+
 struct PyGeneratorParams {
   PyGeneratorParams(const Model& model) : params_{std::make_shared<GeneratorParams>(model)} {
   }
@@ -238,6 +245,11 @@ struct PyGeneratorParams {
     refs_.emplace_back(value);
   }
 
+  void SetInputs(std::shared_ptr<PyNamedTensors> named_tensors) {
+    params_->SetInputs(*named_tensors->named_tensors_);
+    named_tensors_ = named_tensors;
+  }
+
   void SetSearchOptions(const pybind11::kwargs& dict) {
     for (auto& entry : dict) {
       auto name = entry.first.cast<std::string>();
@@ -268,14 +280,8 @@ struct PyGeneratorParams {
   pybind11::array py_whisper_input_features_;
   pybind11::array py_alignment_heads_;
 
-  std::vector<pybind11::object> refs_;  // References to data we want to ensure doesn't get garbage collected
-};
-
-struct PyNamedTensors {
-  PyNamedTensors(std::unique_ptr<NamedTensors> named_tensors) : named_tensors_{std::move(named_tensors)} {
-  }
-
-  std::unique_ptr<NamedTensors> named_tensors_;
+  std::vector<pybind11::object> refs_;             // References to data we want to ensure doesn't get garbage collected
+  std::shared_ptr<PyNamedTensors> named_tensors_;  // Ensure the model inputs don't get garbage collected
 };
 
 struct PyGenerator {
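These hunks are the heart of the change: PyGeneratorParams::SetInputs forwards the tensors to the underlying GeneratorParams and also stores a shared_ptr to the PyNamedTensors wrapper, so the input data stays alive even after Python drops its last reference to the NamedTensors object. A minimal, self-contained sketch of the same keep-alive pattern follows; the Inputs/Params names and the module name are illustrative only, not onnxruntime-genai API.

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Illustrative stand-ins for the real NamedTensors/GeneratorParams types.
struct Inputs {
  int data = 42;
};

struct Params {
  // Storing a shared_ptr member keeps the inputs alive while Params (and
  // anything consuming it on the C++ side) is alive, even if the Python
  // object that produced them is garbage collected.
  void set_inputs(std::shared_ptr<Inputs> inputs) {
    inputs_ = std::move(inputs);
  }
  std::shared_ptr<Inputs> inputs_;
};

PYBIND11_MODULE(keepalive_sketch, m) {
  // std::shared_ptr holders let Python and C++ share ownership of one object.
  py::class_<Inputs, std::shared_ptr<Inputs>>(m, "Inputs")
      .def(py::init<>());
  py::class_<Params, std::shared_ptr<Params>>(m, "Params")
      .def(py::init<>())
      .def("set_inputs", &Params::set_inputs);  // argument arrives as a shared_ptr
}

With the default unique_ptr holder and a raw-pointer parameter, Params would be left holding a dangling pointer as soon as the Python Inputs object was garbage collected; sharing ownership removes that hazard.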
@@ -387,11 +393,11 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       // TODO(baijumeswani): Rename/redesign the whisper_input_features to be more generic
       .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
       .def_readwrite("alignment_heads", &PyGeneratorParams::py_alignment_heads_)
-      .def("set_inputs", [](PyGeneratorParams& generator_params, PyNamedTensors* named_tensors) {
+      .def("set_inputs", [](PyGeneratorParams& generator_params, std::shared_ptr<PyNamedTensors> named_tensors) {
         if (!named_tensors || !named_tensors->named_tensors_)
           throw std::runtime_error("No inputs provided.");
 
-        generator_params.params_->SetInputs(*named_tensors->named_tensors_);
+        generator_params.SetInputs(named_tensors);
       })
       .def("set_model_input", &PyGeneratorParams::SetModelInput)
       .def("set_search_options", &PyGeneratorParams::SetSearchOptions)  // See config.h 'struct Search' for the options
@@ -456,8 +462,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
         generator.SetActiveAdapter(adapters, adapter_name);
       });
 
-  pybind11::class_<Images>(m, "Images")
-      .def_static("open", [](pybind11::args image_paths) {
+  pybind11::class_<Images, std::shared_ptr<Images>>(m, "Images")
+      .def_static("open", [](pybind11::args image_paths) -> std::shared_ptr<Images> {
         if (image_paths.empty())
           throw std::runtime_error("No images provided");
 
@@ -470,7 +476,7 @@
           image_paths_vector.push_back(image_paths_string.back().c_str());
         }
 
-        return LoadImages(image_paths_vector);
+        return std::shared_ptr<Images>(LoadImages(image_paths_vector));
       })
       .def_static("open_bytes", [](pybind11::args image_datas) {
         if (image_datas.empty())
@@ -486,10 +492,10 @@
           image_raw_data[i] = ort_extensions::ImageRawData(data, data + info.size);
         }
 
-        return std::make_unique<Images>(std::move(image_raw_data), image_datas.size());
+        return std::make_shared<Images>(std::move(image_raw_data), image_datas.size());
       });
 
-  pybind11::class_<Audios>(m, "Audios")
+  pybind11::class_<Audios, std::shared_ptr<Audios>>(m, "Audios")
       .def_static("open", [](pybind11::args audio_paths) {
         if (audio_paths.empty())
           throw std::runtime_error("No audios provided");
@@ -504,14 +510,14 @@
           audio_paths_vector.push_back(audio_paths_string.back().c_str());
         }
 
-        return LoadAudios(audio_paths_vector);
+        return std::shared_ptr<Audios>(LoadAudios(audio_paths_vector));
       });
 
-  pybind11::class_<PyNamedTensors>(m, "NamedTensors");
+  pybind11::class_<PyNamedTensors, std::shared_ptr<PyNamedTensors>>(m, "NamedTensors");
 
   pybind11::class_<MultiModalProcessor, std::shared_ptr<MultiModalProcessor>>(m, "MultiModalProcessor")
       .def(
-          "__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::unique_ptr<PyNamedTensors> {
+          "__call__", [](MultiModalProcessor& processor, const std::optional<std::string>& prompt, const pybind11::kwargs& kwargs) -> std::shared_ptr<PyNamedTensors> {
            if (kwargs.contains("images")) {
              if (processor.image_processor_ == nullptr) {
                throw std::runtime_error("Image processor is not available for this model.");
@@ -520,11 +526,11 @@
            if (!prompt.has_value()) {
              throw std::runtime_error("Prompt is required for processing the image.");
            }
-            return std::make_unique<PyNamedTensors>(
+            return std::make_shared<PyNamedTensors>(
                processor.image_processor_->Process(*processor.tokenizer_, *prompt, images));
          } else if (kwargs.contains("audios")) {
            const Audios* audios = kwargs["audios"].cast<const Audios*>();
-            return std::make_unique<PyNamedTensors>(
+            return std::make_shared<PyNamedTensors>(
                processor.audio_processor_->Process(audios));
          } else {
            throw std::runtime_error("Nothing to process.");
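The remaining hunks switch Images, Audios, and NamedTensors to std::shared_ptr holder types, so objects produced by the factory functions can be co-owned by C++ (for example by the generator params above) and by Python. Below is a standalone sketch of the holder-type idea, assuming made-up Blob/load_blob names rather than the real Images/LoadImages API.

#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace py = pybind11;

// Hypothetical resource type standing in for Images/Audios.
struct Blob {
  std::vector<std::string> paths;
};

// A factory that, like LoadImages/LoadAudios, hands back a unique_ptr.
std::unique_ptr<Blob> load_blob(std::vector<std::string> paths) {
  return std::make_unique<Blob>(Blob{std::move(paths)});
}

PYBIND11_MODULE(holder_sketch, m) {
  // Declaring the holder as std::shared_ptr<Blob> means every Blob handed to
  // Python is wrapped in a shared_ptr; C++ code can copy that shared_ptr and
  // keep the object alive independently of the Python reference count.
  py::class_<Blob, std::shared_ptr<Blob>>(m, "Blob")
      .def_static("open", [](std::vector<std::string> paths) -> std::shared_ptr<Blob> {
        // Converting the unique_ptr into a shared_ptr mirrors the return-type
        // change made to Images.open and Audios.open in this diff.
        return std::shared_ptr<Blob>(load_blob(std::move(paths)));
      });
}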