From 6eaf71419f033f43b1a5dbaceb6fa6dfbd4fcb32 Mon Sep 17 00:00:00 2001 From: Mehrdad Khani Date: Mon, 4 Nov 2024 15:24:23 -0800 Subject: [PATCH] [XLA:MSA] Remove unnecessary Extend() call in memory space assignment. This Extend() call would also lead to a memory assignment issue since it wasn't accompanied by the necessary chunk commit requests. We also add a VerifyAllocations() function that uses a BufferIntervalTree to check for overlapping Allocations before scheduling the asynchronous copies. This is an extra check for the correctness of MsaAlgorithm allocations, and is only applied if options_.verify is enabled in MSA options. options_.verify is disabled by default. PiperOrigin-RevId: 693110290 --- xla/service/memory_space_assignment/BUILD | 16 +- .../memory_space_assignment/algorithm.cc | 213 ++++++------ .../memory_space_assignment/algorithm.h | 321 ++---------------- .../allocation_value.h | 300 ++++++++++++++++ .../memory_space_assignment.cc | 34 ++ .../memory_space_assignment.h | 5 + .../memory_space_assignment_test.cc | 169 ++++++++- xla/service/memory_space_assignment/options.h | 16 +- 8 files changed, 677 insertions(+), 397 deletions(-) create mode 100644 xla/service/memory_space_assignment/allocation_value.h diff --git a/xla/service/memory_space_assignment/BUILD b/xla/service/memory_space_assignment/BUILD index 9b46280d9545c..0c8166943fcba 100644 --- a/xla/service/memory_space_assignment/BUILD +++ b/xla/service/memory_space_assignment/BUILD @@ -79,6 +79,7 @@ xla_cc_test( deps = [ ":algorithm", ":allocation", + ":allocation_value", ":buffer_interval_comparator", ":cost_analysis", ":memory_space_assignment", @@ -107,7 +108,6 @@ xla_cc_test( "//xla/service/heap_simulator:allocation_block", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", @@ -273,6 +273,7 @@ cc_library( srcs = [], hdrs = ["options.h"], deps = [ + ":allocation_value", ":buffer_interval_comparator", ":cost_analysis", ":memory_space_assignment_proto_cc", @@ -508,6 +509,7 @@ cc_library( hdrs = ["algorithm.h"], deps = [ ":allocation", + ":allocation_value", ":buffer_interval_comparator", ":cost_analysis", ":memory_bound_loop_optimizer", @@ -577,6 +579,18 @@ cc_library( ], ) +cc_library( + name = "allocation_value", + hdrs = ["allocation_value.h"], + deps = [ + ":allocation", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/types:span", + ], +) + xla_cc_test( name = "prefetch_interval_picker_test", srcs = ["prefetch_interval_picker_test.cc"], diff --git a/xla/service/memory_space_assignment/algorithm.cc b/xla/service/memory_space_assignment/algorithm.cc index bfc731f9d11e4..1f61efd4fb76c 100644 --- a/xla/service/memory_space_assignment/algorithm.cc +++ b/xla/service/memory_space_assignment/algorithm.cc @@ -63,6 +63,7 @@ limitations under the License. 
#include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/allocation_value.h" #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/service/memory_space_assignment/memory_bound_loop_optimizer.h" @@ -1828,7 +1829,7 @@ absl::StatusOr> MsaAlgorithm::Finish() { } options_.prefetch_interval_picker->SetRetryNumber(retry_number); TF_ASSIGN_OR_RETURN( - Result result, + AllocationResult result, AllocateAllocationValues(absl::MakeSpan(proposal.allocation_values))); VLOG(2) << "Allocation result = " << ResultToString(result); VLOG(3) << "--Allocations List Begin--"; @@ -1845,10 +1846,11 @@ absl::StatusOr> MsaAlgorithm::Finish() { } } VLOG(3) << "--Allocations List End--"; - if (result_is(result, Result::kFailSyncDataMoveReplacement)) { + if (result_is(result, AllocationResult::kFailSyncDataMoveReplacement)) { CHECK(options_.enable_sync_copy_replacement || options_.enable_sync_slice_replacement) - << "Allocation result is Result::kFailSyncCopyReplacement, but " + << "Allocation result is " + "AllocationResult::kFailSyncCopyReplacement, but " "no sync replacement is enabled."; UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values)); proposal = GetJointProposal(interval); @@ -1871,7 +1873,7 @@ absl::StatusOr> MsaAlgorithm::Finish() { proposal = GetJointProposal(interval); --retry_number; } - } else if ((result_is(result, Result::kFailOutOfMemory) || + } else if ((result_is(result, AllocationResult::kFailOutOfMemory) || options_.repack_after_every_allocation) && num_repacks_ < options_.max_repacks && !repacked && !RepackAllocationsIncludeConvertedSyncMemOp()) { @@ -2366,7 +2368,7 @@ MsaAlgorithm::GenerateAllocationSegmentContexts( return uses_work_list; } -absl::StatusOr MsaAlgorithm::AllocateAllocationValues( +absl::StatusOr MsaAlgorithm::AllocateAllocationValues( absl::Span allocation_values) { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); absl::flat_hash_map> @@ -2400,7 +2402,7 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( preferred_offset_for_allocation_value; absl::flat_hash_map definition_time_for_allocation_value; - Result result = Result::kSuccess; + AllocationResult result = AllocationResult::kSuccess; for (int alloc_value_idx = 0; alloc_value_idx < allocation_values.size(); ++alloc_value_idx) { auto& allocation_value = allocation_values.at(alloc_value_idx); @@ -2473,6 +2475,9 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( definition_time_for_allocation_value.at(&allocation_value_to_update), RequiresNoCopyAlternateMemAllocation(allocation_value_to_update), all_use_times, entry.only_extend_existing_allocation); + if (options_.allocation_request_modifier_testing_fn) { + options_.allocation_request_modifier_testing_fn(request); + } // Bitcasts don't define buffers and don't directly consume buffers. // Skip allocating buffers for bitcast uses (unless they are the root // instruction). 
The uses that feed from bitcasts will be handled @@ -2481,6 +2486,9 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( use.hlo_use.instruction == use.hlo_use.instruction->parent()->root_instruction()) { result_mark(AllocateSegment(request), result); + if (options_.allocation_result_modifier_testing_fn) { + options_.allocation_result_modifier_testing_fn(request, result); + } if (request.require_copy_allocation) { auto allocation_sequence = allocation_value_to_update.mutable_allocation_sequence(); @@ -2523,8 +2531,8 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( "normal mode."; failed_async_conversions_[request.required_copy_allocation_for] = AsyncConversionResult::kFailedSatisfyingConstraints; - result_mark(Result::kFailSyncDataMoveReplacement, result); - result_mark(Result::kFailRequiresUncommit, result); + result_mark(AllocationResult::kFailSyncDataMoveReplacement, result); + result_mark(AllocationResult::kFailRequiresUncommit, result); } else { bool has_correct_use = false; for (auto& alloc_use : (*it)->uses()) { @@ -2540,8 +2548,9 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( "normal mode."; failed_async_conversions_[request.required_copy_allocation_for] = AsyncConversionResult::kFailedPrecondition; - result_mark(Result::kFailSyncDataMoveReplacement, result); - result_mark(Result::kFailRequiresUncommit, result); + result_mark(AllocationResult::kFailSyncDataMoveReplacement, + result); + result_mark(AllocationResult::kFailRequiresUncommit, result); } else { not_finalized_async_conversions_.push_back( request.required_copy_allocation_for); @@ -2552,7 +2561,7 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( } } if (request.require_no_copy_alternate_mem_allocation && - result != Result::kSuccess) { + result != AllocationResult::kSuccess) { absl::Status failed_precondition = FailedPrecondition( "The value defined at %s requires allocation in the alternate " "memory, which could not be satisfied. This typically happens " @@ -2588,8 +2597,8 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( } if (!VerifyAllConversionsAreSuccessful()) { - result_mark(Result::kFailSyncDataMoveReplacement, result); - result_mark(Result::kFailRequiresUncommit, result); + result_mark(AllocationResult::kFailSyncDataMoveReplacement, result); + result_mark(AllocationResult::kFailRequiresUncommit, result); } return result; @@ -2613,9 +2622,8 @@ bool MsaAlgorithm::VerifyAllConversionsAreSuccessful() { return true; } -MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse( - const AllocationValue::Use& use, - MsaAlgorithm::AliasedOffset* preferred_offset) const { +AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse( + const AllocationValue::Use& use, AliasedOffset* preferred_offset) const { // Assign the required assignment offset as a preferred offset. 
std::optional required_assignment = AliasedRequiredAssignmentForUse(use); @@ -2632,7 +2640,7 @@ MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse( return preferred_offset; } -MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( +AllocationRequest MsaAlgorithm::CreateAllocationRequest( AllocationValue& allocation_value, AllocationValue& allocation_value_to_update, const AllocationValue::Use& use, const AllocationValue::Use* previous_use, @@ -3313,15 +3321,14 @@ std::string AsynchronousCopyResource::Dump( return absl::StrJoin(lines, "\n"); } -MsaAlgorithm::AliasedOffset* MsaAlgorithm::GetAliasedOffset( - const Allocation& allocation) { +AliasedOffset* MsaAlgorithm::GetAliasedOffset(const Allocation& allocation) { auto aliased_offset_it = aliased_offset_map_.find(&allocation); CHECK(aliased_offset_it != aliased_offset_map_.end()); return aliased_offset_it->second; } -void MsaAlgorithm::CreateOrAddToAliasedOffset( - const Allocation& allocation, MsaAlgorithm::AliasedOffset* aliased_offset) { +void MsaAlgorithm::CreateOrAddToAliasedOffset(const Allocation& allocation, + AliasedOffset* aliased_offset) { CHECK(allocation.memory_space() == MemorySpace::kAlternate); CHECK(!aliased_offset_map_.contains(&allocation)); if (!aliased_offset) { @@ -4284,43 +4291,45 @@ std::optional MsaAlgorithm::FindEarliestExclusiveTimeToSatisfyPeakMemory( return earliest_time_exclusive; } -std::string MsaAlgorithm::SingleFailureResultToString(const Result& result) { +std::string MsaAlgorithm::SingleFailureResultToString( + const AllocationResult& result) { switch (result) { - case Result::kSuccess: + case AllocationResult::kSuccess: return "Success"; - case Result::kFailOutOfMemory: + case AllocationResult::kFailOutOfMemory: return "FailOutOfMemory"; - case Result::kFailPrevAllocationNotInAlternateMem: + case AllocationResult::kFailPrevAllocationNotInAlternateMem: return "FailPrevAllocationNotInAlternateMem"; - case Result::kFailLiveRangeTooLong: + case AllocationResult::kFailLiveRangeTooLong: return "FailLiveRangeTooLong"; - case Result::kFailLiveRangeTooShort: + case AllocationResult::kFailLiveRangeTooShort: return "FailLiveRangeTooShort"; - case Result::kFailOutOfAsyncCopies: + case AllocationResult::kFailOutOfAsyncCopies: return "FailOutOfAsyncCopies"; - case Result::kFailViolatesAsyncCopyResource: + case AllocationResult::kFailViolatesAsyncCopyResource: return "FailViolatesAsyncCopyResource"; - case Result::kFailRequiresUncommit: + case AllocationResult::kFailRequiresUncommit: return "FailRequiresUncommit"; - case Result::kAllSlicesHaveTheSameStartTime: + case AllocationResult::kAllSlicesHaveTheSameStartTime: return "AllSlicesHaveTheSameStartTime"; - case Result::kFailConflictingPreferredOffsets: + case AllocationResult::kFailConflictingPreferredOffsets: return "FailConflictingPreferredOffsets"; - case Result::kFailSyncDataMoveReplacement: + case AllocationResult::kFailSyncDataMoveReplacement: return "FailSyncDataMoveReplacement"; default: return "UnknownResult"; } } -std::string MsaAlgorithm::ResultToString(const Result& result) { - if (result == Result::kSuccess) { +std::string MsaAlgorithm::ResultToString(const AllocationResult& result) { + if (result == AllocationResult::kSuccess) { return "Success"; } std::string result_str = ""; for (int failure_order = 0; failure_order < 16; ++failure_order) { - Result failure_value = static_cast(1 << failure_order); - if (result_is(result, static_cast(failure_value))) { + AllocationResult failure_value = + static_cast(1 << 
failure_order); + if (result_is(result, static_cast(failure_value))) { result_str += (SingleFailureResultToString(failure_value) + " | "); } } @@ -4328,7 +4337,7 @@ std::string MsaAlgorithm::ResultToString(const Result& result) { return result_str; } -MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { +AllocationResult MsaAlgorithm::AllocateSegment(AllocationRequest& request) { auto allocation_sequence = request.allocation_value->mutable_allocation_sequence(); // inclusive_start_time == end_time is a special case where the value is @@ -4340,7 +4349,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { request.end_time); CHECK_NE(allocation, nullptr); allocation->AddUse(request.use->hlo_use); - return Result::kSuccess; + return AllocationResult::kSuccess; } const HloPosition& defining_position = @@ -4375,10 +4384,6 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { *use.instruction, use.operand_number, use.operand_index); } - if (request.only_extend_existing_allocation && - !allocation_sequence->empty()) { - allocation_sequence->back()->Extend(request.inclusive_start_time); - } // There could be a requirement to pin this buffer to default memory either // because it is a parameter or an output. If the buffer is a parameter, then // we're allowed to prefetch. If the use expects the output to be in default @@ -4443,15 +4448,15 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { } } - Result allocation_result = Result::kSuccess; + AllocationResult allocation_result = AllocationResult::kSuccess; // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && request.allow_no_copy_alternate_mem_allocation && !request.require_copy_allocation) { allocation_result = AllocateInAlternateMemoryNoCopy(request); - if (allocation_result == Result::kSuccess) { - return Result::kSuccess; + if (allocation_result == AllocationResult::kSuccess) { + return AllocationResult::kSuccess; } // If we required alternate memory allocation, return on failure. if (request.require_no_copy_alternate_mem_allocation) { @@ -4477,10 +4482,11 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { (*prev_allocation_it)->defining_position() == defining_position) { // If there was an allocation for this HloValue that was in the alternate // memory space, we also need to perform an eviction. - Result eviction_result = Evict(request); - if (eviction_result != Result::kSuccess) { + AllocationResult eviction_result = Evict(request); + if (eviction_result != AllocationResult::kSuccess) { // A non-success eviction requires us to uncommit previous allocations. 
- return result_mark(Result::kFailRequiresUncommit, eviction_result); + return result_mark(AllocationResult::kFailRequiresUncommit, + eviction_result); } prev_allocation_in_default_mem_it = allocation_sequence->rbegin(); } else if (prev_allocation_in_default_mem_it == @@ -4495,7 +4501,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) { VLOG(3) << "Allocation requires contiguous allocation, but it wasn't " "possible to find one."; - return result_mark(Result::kFailRequiresUncommit, allocation_result); + return result_mark(AllocationResult::kFailRequiresUncommit, + allocation_result); } CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend()); @@ -4511,7 +4518,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { required_memory_space_at_start != required_memory_space_at_end) { VLOG(3) << "Allocation requires contiguous allocation but has memory space " "mismatch."; - return result_mark(Result::kFailRequiresUncommit, allocation_result); + return result_mark(AllocationResult::kFailRequiresUncommit, + allocation_result); } // If the buffer must be in default memory at the end_time, don't prefetch. @@ -4525,7 +4533,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // prefetching it, which will try to prefetch only a window worth of data to // alternate memory. WindowPrefetch(request, **prev_allocation_in_default_mem_it); - return Result::kSuccess; + return AllocationResult::kSuccess; } // Finally, try to prefetch the buffer into alternate memory. @@ -4555,9 +4563,9 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { request.latest_prefetch_time = latest_prefetch_time; } } - Result prefetch_result = + AllocationResult prefetch_result = Prefetch(request, **prev_allocation_in_default_mem_it); - if (prefetch_result == Result::kSuccess) { + if (prefetch_result == AllocationResult::kSuccess) { if (request.preferred_prefetch_time) { // Warn if the prefetch time picked doesn't match the preferred prefetch // time. @@ -4587,7 +4595,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { << "): " << request.use->hlo_use.ToString(); } } - return Result::kSuccess; + return AllocationResult::kSuccess; } // Warn if there was a preferred prefetch time but we couldn't actually // prefetch. @@ -4603,7 +4611,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // If the end assignment was required to be in alternate memory but that // wasn't possible, then this allocation is invalid. if (required_memory_space_at_end == MemorySpace::kAlternate) { - return result_mark(Result::kFailRequiresUncommit, allocation_result); + return result_mark(AllocationResult::kFailRequiresUncommit, + allocation_result); } // If the start assignment was required to be in alternate memory and the @@ -4611,7 +4620,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // and must abort. 
if (required_memory_space_at_start == MemorySpace::kAlternate && request.allocation_value->requires_contiguous_allocation()) { - return result_mark(Result::kFailRequiresUncommit, allocation_result); + return result_mark(AllocationResult::kFailRequiresUncommit, + allocation_result); } // If a copy wasn't inserted, then add this use to the latest allocation in @@ -4794,7 +4804,7 @@ bool MsaAlgorithm::ViolatesMaximumOutstandingAsyncCopies( } } -MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( +AllocationResult MsaAlgorithm::AllocateInAlternateMemoryNoCopy( const AllocationRequest& request) { Allocation* prev_allocation = nullptr; bool can_eliminate_copy = false; @@ -4815,7 +4825,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( if (!can_eliminate_copy) { VLOG(3) << "Can't eliminate copy."; - return Result::kFailPrevAllocationNotInAlternateMem; + return AllocationResult::kFailPrevAllocationNotInAlternateMem; } const HloPosition& defining_position = @@ -4828,7 +4838,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( defining_position.shape(), request.inclusive_start_time, request.end_time)) { VLOG(3) << "Live range is too long."; - return Result::kFailLiveRangeTooLong; + return AllocationResult::kFailLiveRangeTooLong; } MsaBufferInterval alternate_mem_interval; @@ -4854,7 +4864,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( "preferred_offset = " << preferred_offset->offset << ", request.preferred_offset = " << request.preferred_offset->offset; - return Result::kFailConflictingPreferredOffsets; + return AllocationResult::kFailConflictingPreferredOffsets; } preferred_offset = request.preferred_offset; } @@ -4918,16 +4928,16 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); } - return Result::kSuccess; + return AllocationResult::kSuccess; } if (request.prefer_no_copy_alternate_mem_allocation) { VLOG(1) << "Preferred no-copy allocation, but this was not possible: " << request.use->hlo_use.ToString(); } - return Result::kFailOutOfMemory; + return AllocationResult::kFailOutOfMemory; } -MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) { +AllocationResult MsaAlgorithm::Evict(const AllocationRequest& request) { CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0); Allocation* prev_allocation = request.allocation_value->allocation_sequence()->back().get(); @@ -5053,10 +5063,10 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) { << hlo_live_range_.flattened_instruction_sequence() .instructions()[eviction_end_time] << ")"; - return Result::kFailOutOfAsyncCopies; + return AllocationResult::kFailOutOfAsyncCopies; } } - return Result::kSuccess; + return AllocationResult::kSuccess; } int64_t MsaAlgorithm::FindPrefetchEndTime( @@ -5091,11 +5101,11 @@ std::string DescribeSlicedBufferMove( } // namespace -MsaAlgorithm::Result MsaAlgorithm::WindowPrefetch( +AllocationResult MsaAlgorithm::WindowPrefetch( const AllocationRequest& request, Allocation& prev_allocation_in_default_mem) { if (!options_.enable_window_prefetch) { - return Result::kSuccess; + return AllocationResult::kSuccess; } const HloUse use = request.use->hlo_use; @@ -5120,10 +5130,10 @@ MsaAlgorithm::Result MsaAlgorithm::WindowPrefetch( const Shape shape = ShapeUtil::MakeShape(U8, {window.size()}); Prefetch(window_prefetch_request, prev_allocation_in_default_mem, &shape); } - return 
Result::kSuccess; + return AllocationResult::kSuccess; } -MsaAlgorithm::Result MsaAlgorithm::Prefetch( +AllocationResult MsaAlgorithm::Prefetch( const AllocationRequest& request, Allocation& prev_allocation_in_default_mem, const Shape* shape) { // Try partially placing the buffer in the alternate space. The time that is @@ -5157,12 +5167,12 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( SetupPrefetchWorkingIntervalsAndSliceProposal(context); // Compute some additional preliminaries - Result init_result = InitializePrefetchIntervalPicker(context); - if (init_result != Result::kSuccess) { + AllocationResult init_result = InitializePrefetchIntervalPicker(context); + if (init_result != AllocationResult::kSuccess) { return init_result; } - Result check_result = EnsureSomeSpatialPrefetchFitExists(context); - if (check_result != Result::kSuccess) { + AllocationResult check_result = EnsureSomeSpatialPrefetchFitExists(context); + if (check_result != AllocationResult::kSuccess) { return check_result; } const HloUse& use = request.use->hlo_use; @@ -5191,7 +5201,7 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( // request and a non-sliced version of the request. We return the first sliced // solution that we find. We fallback to the first unsliced solution we find, // if we are unable to find a sliced solution. - Result result = Result::kSuccess; + AllocationResult result = AllocationResult::kSuccess; while (!options_.prefetch_interval_picker->Done()) { // Get the prefetch start time from the interval picker. context.exclusive_prefetch_start_time = @@ -5201,19 +5211,20 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( context.exclusive_prefetch_start_time <= *context.exclusive_out_of_mem_start) { VLOG(4) << "This would OOM (cached)."; - return Result::kFailOutOfMemory; + return AllocationResult::kFailOutOfMemory; } if (context.slice_proposal_collection) { VLOG(5) << "Trying sliced solution."; // Check if a sliced solution fits. - Result sliced_result = + AllocationResult sliced_result = CheckPrefetchFit(/*for_sliced_solution=*/true, context); - if (sliced_result == Result::kSuccess) { + if (sliced_result == AllocationResult::kSuccess) { // Break out of the loop and use the sliced solution. CHECK(context.sliced_solution); break; - } else if (sliced_result != Result::kAllSlicesHaveTheSameStartTime) { + } else if (sliced_result != + AllocationResult::kAllSlicesHaveTheSameStartTime) { result_mark(sliced_result, result); } } @@ -5221,9 +5232,9 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( // If we don't already have an unsliced solution, check the current fit. 
if (!context.unsliced_solution) { VLOG(5) << "Trying unsliced solution."; - Result unsliced_result = + AllocationResult unsliced_result = CheckPrefetchFit(/*for_sliced_solution=*/false, context); - if (unsliced_result != Result::kSuccess) { + if (unsliced_result != AllocationResult::kSuccess) { result_mark(unsliced_result, result); } else if (!context.slice_proposal_collection) { // We found an unsliced solution and there is no slice proposal, so @@ -5257,7 +5268,7 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( context.request->allocation_value_to_update->allocation_sequence() ->back() ->AddUse(context.request->use->hlo_use); - return Result::kSuccess; + return AllocationResult::kSuccess; } if (context.unsliced_solution) { VLOG(3) << "Move the buffer to alternate memory after time " @@ -5302,12 +5313,14 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( request.allocation_value_to_update->allocation_sequence()->back()->AddUse( request.use->hlo_use); - return Result::kSuccess; + return AllocationResult::kSuccess; } // If we didn't consider any prefetch intervals, then the live range was too // short. - return (result == Result::kSuccess ? Result::kFailLiveRangeTooShort : result); + return (result == AllocationResult::kSuccess + ? AllocationResult::kFailLiveRangeTooShort + : result); } void MsaAlgorithm::GenerateSliceProposal(PrefetchContext& context) const { @@ -5403,7 +5416,7 @@ void MsaAlgorithm::SetupPrefetchWorkingIntervalsAndSliceProposal( context.unsliced_solution_intervals.full)); } -MsaAlgorithm::Result MsaAlgorithm::InitializePrefetchIntervalPicker( +AllocationResult MsaAlgorithm::InitializePrefetchIntervalPicker( PrefetchContext& context) { int64_t earliest_exclusive_prefetch_time = context.prev_allocation_in_default_mem->earliest_available_time(); @@ -5425,7 +5438,7 @@ MsaAlgorithm::Result MsaAlgorithm::InitializePrefetchIntervalPicker( VLOG(3) << "Any prefetch in range (" << earliest_exclusive_prefetch_time << ", " << context.prefetch_end_time << ") for size " << context.request->size << " would go out of memory."; - return Result::kFailOutOfMemory; + return AllocationResult::kFailOutOfMemory; } if (!context.slice_proposal_collection) { // We can only perform this optimization if we are not slicing. @@ -5452,10 +5465,10 @@ MsaAlgorithm::Result MsaAlgorithm::InitializePrefetchIntervalPicker( VLOG(3) << "Trying prefetch picker = " << options_.prefetch_interval_picker->ToDebugString(); - return Result::kSuccess; + return AllocationResult::kSuccess; } -MsaAlgorithm::Result MsaAlgorithm::EnsureSomeSpatialPrefetchFitExists( +AllocationResult MsaAlgorithm::EnsureSomeSpatialPrefetchFitExists( PrefetchContext& context) const { SlicedBufferInterval* interval = (context.slice_proposal_collection @@ -5473,10 +5486,10 @@ MsaAlgorithm::Result MsaAlgorithm::EnsureSomeSpatialPrefetchFitExists( VLOG(3) << "The latest prefetch (" << interval->full_buffer_interval().start << ", " << context.request->end_time << ") cannot find valid chunks. 
Giving up."; - return Result::kFailOutOfMemory; + return AllocationResult::kFailOutOfMemory; } - return Result::kSuccess; + return AllocationResult::kSuccess; } namespace { @@ -5588,8 +5601,8 @@ absl::flat_hash_map GetCandidateToProposalIndexMap( } // namespace -MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, - PrefetchContext& context) { +AllocationResult MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, + PrefetchContext& context) { SlicedBufferInterval* sliced_buffer_interval = context.GetMutableWorkingIntervals(for_sliced_solution).sliced.get(); @@ -5631,7 +5644,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, exclusive_slice_start_times, [&](int64_t slice_start_time) { return slice_start_time == exclusive_slice_start_times.front(); })) { - return Result::kAllSlicesHaveTheSameStartTime; + return AllocationResult::kAllSlicesHaveTheSameStartTime; } // Check that we have enough copy resource for the prefetching. @@ -5666,7 +5679,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, context.prefetch_end_time, copy_resource_per_slice_sorted_by_start_time, prefetch_async_copy_resource_)) { - return Result::kFailViolatesAsyncCopyResource; + return AllocationResult::kFailViolatesAsyncCopyResource; } // Check if the copies we would add for the prefetch would violate copy @@ -5678,7 +5691,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, slice_start_time, context.prefetch_end_time); })) { VLOG(4) << "This would violate asynchronous copy ordering."; - return Result::kFailViolatesAsyncCopyResource; + return AllocationResult::kFailViolatesAsyncCopyResource; } // Check if the copies we would add for the prefetch violate the maximum @@ -5688,7 +5701,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, exclusive_slice_start_times[i], context.prefetch_end_time, /*is_prefetch=*/true, context.extra_async_copy_limit, i)) { VLOG(4) << "This would violate the outstanding async copy limit."; - return Result::kFailOutOfAsyncCopies; + return AllocationResult::kFailOutOfAsyncCopies; } } @@ -5730,7 +5743,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, exclusive_slice_start_times, context.prefetch_end_time, copy_resource_per_slice_sorted_by_start_time, prefetch_async_copy_resource_)) { - return Result::kFailViolatesAsyncCopyResource; + return AllocationResult::kFailViolatesAsyncCopyResource; } // Construct MsaBufferInterval-Chunk pairs that are appropriate for pending @@ -5791,7 +5804,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, std::move(slices_for_pending_chunks), prefetch_picker_debug_string, }; - return Result::kSuccess; + return AllocationResult::kSuccess; } else if (!chunk_candidates.empty()) { // We're trying an unsliced solution. So, if FindBestChunkCandidates() found // a solution, there must be only 1 chunk for it. 
@@ -5802,7 +5815,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, copy_resource_per_slice_sorted_by_start_time.front(), prefetch_picker_debug_string, }; - return Result::kSuccess; + return AllocationResult::kSuccess; } // Mark the out of memory start with the prefetch start time so that we don't @@ -5819,7 +5832,7 @@ MsaAlgorithm::Result MsaAlgorithm::CheckPrefetchFit(bool for_sliced_solution, } VLOG(4) << "Out of memory."; - return Result::kFailOutOfMemory; + return AllocationResult::kFailOutOfMemory; } std::string MsaAlgorithm::AlternateMemoryAllocationAttemptToString( diff --git a/xla/service/memory_space_assignment/algorithm.h b/xla/service/memory_space_assignment/algorithm.h index 756a1a6bf9433..d550114383365 100644 --- a/xla/service/memory_space_assignment/algorithm.h +++ b/xla/service/memory_space_assignment/algorithm.h @@ -48,6 +48,7 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" +#include "xla/service/memory_space_assignment/allocation_value.h" #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h" #include "xla/service/memory_space_assignment/options.h" @@ -58,154 +59,6 @@ limitations under the License. namespace xla { namespace memory_space_assignment { - -// AllocationValue is used to break up HloValues for each non-trivial position -// (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An -// HloValue may include positions and uses that alias with each other across -// multiple computations. We use this class to break these HloValues such that -// every AllocationValue has one defining position (that may alias with other -// AllocationValues). The uses field of the AllocationValue contains only the -// direct uses of the AllocationValue's defining position. -// -// For example, consider the following HLO snippet: -// -// Body { -// body_param = (f32[4,3]{1,0}, f32[]) parameter(0) -// get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param), -// index=0 -// ... -// ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...) -// } -// -// Cond { -// cond_param = (f32[4,3]{1,0}, f32[]) parameter(0) -// ... -// } -// -// add.4 = f32[4,3]{1,0} add(...) -// tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...) -// while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond -// get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0 -// add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...) 
-// -// This contains an HloValue that looks like the following: -// positions: -// add.4 -// body_param {0} -// get-tuple-element.3 -// tuple {0} -// cond_param {0} -// tuple.1 {0} -// while {0} -// get-tuple-element.5 -// uses: -// add.1, operand 0 -// tuple, operand 0 -// while, operand 0 {0} -// add.5, operand 0 -// -// We break this HloValue up into the following AllocationValues for each -// non-trivial position: -// AllocationValue1: computation = Entry -// position: -// add.4 -// uses: -// while, operand 0 {0} -// AllocationValue2: computation = Cond -// position: -// cond_param {0} -// uses: -// AllocationValue3: computation = Body -// position: -// body_param {0} -// uses: -// add.1, operand 0 -// tuple, operand 0 -// AllocationValue4: computation = Entry -// position: -// while {0} -// uses: -// add.5, operand 0 -class AllocationValue { - public: - // This data structure wraps an HloUse and adds additional metadata that are - // useful for allocation. - struct Use { - // The wrapped HloUse object. - HloUse hlo_use; - // The logical time this use is scheduled. - int64_t time; - // All the positions where this use aliases with. The aliased positions - // must get the same allocation. - std::vector aliases; - // A synchronous memory operation that feeds this use. - // TODO(mehrdadk): extend this to support multiple sync data movement - // operands. - HloInstruction* sync_mem_op_operand = nullptr; - - bool operator==(const Use& other) const { - return hlo_use == other.hlo_use && time == other.time && - aliases == other.aliases; - } - - template - friend H AbslHashValue(H h, const Use& s) { - return H::combine(std::move(h), s.hlo_use, s.time, s.aliases); - } - }; - - AllocationValue(const HloValue* value, const HloPosition& position, - int64_t size) - : value_(value), - defining_position_(position), - size_(size), - requires_contiguous_allocation_(false) {} - - const HloPosition& defining_position() const { return defining_position_; } - const HloInstruction* defining_instruction() const { - return defining_position().instruction; - } - int64_t size() const { return size_; } - const std::vector& uses() const { return uses_; } - std::vector& uses() { return uses_; } - const HloValue* value() const { return value_; } - const HloComputation* computation() const { - return defining_instruction()->parent(); - } - AllocationSequence* mutable_allocation_sequence() { - return &allocation_sequence_; - } - const AllocationSequence* allocation_sequence() const { - return &allocation_sequence_; - } - - // Sets/gets whether this AllocationValue requires allocating it - // contiguously throughout its live range (without any copies). - bool requires_contiguous_allocation() const { - return requires_contiguous_allocation_; - } - void set_requires_contiguous_allocation(bool requires_contiguous_allocation) { - requires_contiguous_allocation_ = requires_contiguous_allocation; - } - - void AddUse(const HloUse& use, int64_t use_time) { - uses_.push_back({use, use_time, {}}); - } - - std::string ToString() const; - std::string ToShortString() const; - - private: - const HloValue* value_; - HloPosition defining_position_; - int64_t size_; - // If true, there must be a contiguous allocation for this buffer without - // any copies. 
- bool requires_contiguous_allocation_; - std::vector uses_; - AllocationSequence allocation_sequence_; -}; - // A struct representing an asynchronous copy with its logical start and end // time (time that copy done is scheduled), the resource this copy would use, // its destination memory space, and a unique ID. @@ -475,89 +328,6 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { struct RepackAllocationBlock : AllocationBlock { Allocation* allocation; }; - - // A data structure we use to associate Allocation objects that are aliased - // and must get the same offset. - struct AliasedOffset { - int64_t offset; - absl::flat_hash_set allocations; - }; - - // An allocation request for a use segment. A use segment is the time segment - // between the definition and the first use, and the time segment between the - // uses of a buffer. For example, the time between the definition and Use1, is - // the first segment, and the time between Use1 and Use2 is the second segment - // and so on: - // - // +------+----------+-------+ - // / \ \ \ - // / v v v - // Def Use1 Use2 Use3 - // <----------> <--------> <-----> - // Segment Segment Segment - // - // start_time and end_time are the start and end logical times of the segment. - // use_times is a sorted sequence of the times of all uses. - // latest_prefetch_time is the latest time we can schedule the CopyDone for a - // prefetch. - // If allow_no_copy_alternate_mem_allocation is false, an eviction is forced. - // If earliest_prefetch_time is set, prefetches cannot start before this - // value. - // - // In case we are trying to replace synchronous copies, and for example Use2 - // is a replaceable sync copy candidate, we now skip Use2 and segments will be - // between Def, Use1, Use2.1, Use2.2, Use3: - // +------+----------+-------------------+ - // / \ \ \ - // / v v v - // Def Use1 Use2(Sync Copy) Use3 - // | | \ \ | - // | | v v | - // | | Use2.1 Use2.2 | - // |<---------->|<---------->|<------->|<----->| - // | Segment | Segment | Segment |Segment| - - struct AllocationRequest { - int64_t inclusive_start_time; - int64_t end_time; - int64_t latest_prefetch_time; - // See the comment for require_copy_allocation - int64_t required_copy_allocation_latest_time; - int64_t size; - bool prefer_no_copy_alternate_mem_allocation; - bool allow_no_copy_alternate_mem_allocation; - bool require_no_copy_alternate_mem_allocation; - // If true, indicates we are requiring a copy allocation between def and - // use, that finishes by required_copy_allocation_latest_time. - // required_copy_allocation_for is a synchronous copy instruction that will - // be removed, if we are successful in adding the copy allocation. - bool require_copy_allocation; - bool allow_prefetch; - std::optional earliest_prefetch_time; - std::optional preferred_prefetch_time; - AliasedOffset* preferred_offset; - const AllocationValue::Use* use; - AllocationValue* allocation_value; - absl::Span all_use_times; - // See the comment for require_copy_allocation - HloInstruction* required_copy_allocation_for; - // If the required copy in require_copy_allocation is only for a slice of - // the allocation_value - bool required_copy_for_slice; - // The resulting Allocation will be added to the AllocationSequence of - // allocation_value_to_update. We only expect allocation_value_to_update to - // be different from allocation_value in the case of a synchronous memory - // operation conversion to asynchronous, otherwise, they should be the same. 
- AllocationValue* allocation_value_to_update; - // No new Allocation is needed to be created and we will only extend an - // existing one. - bool only_extend_existing_allocation; - // Data structure that contains the options for making window prefetched - // allocations. - const WindowPrefetchedAllocation::Options* window_prefetch_options = - nullptr; - }; - // This struct contains mandatory memory assignments at a given time. E.g., an // input's required memory assignment time would correspond to the definition // time of the parameter instruction, and an output's time would correspond to @@ -718,75 +488,39 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { bool window_prefetch = false; }; - // Result of an allocation, prefetch, eviction etc. request. The result is - // either kSuccess or a bitwise OR of one or more failures. The values are - // unique powers of two. To check if a result contains a particular failure, - // use the result_is method. To add a new failure to a result, use the - // result_mark method. - enum class Result { - // Successful allocation. - kSuccess = 0, - // Allocation failed because we ran out of alternate memory. - kFailOutOfMemory = 1, - // A no-copy allocation couldn't be performed because the previous - // allocation wasn't in the alternate memory space. - kFailPrevAllocationNotInAlternateMem = 2, - // A no-copy allocation couldn't be performed because the live range was too - // long. - kFailLiveRangeTooLong = 4, - // A prefetching couldn't be performed because the live range was too short. - kFailLiveRangeTooShort = 8, - // Ran out of outstanding asynchronous copy limit either during prefetching - // or eviction. - kFailOutOfAsyncCopies = 16, - // A prefetching couldn't be performed because the asynchronous copy - // resource was violated. - kFailViolatesAsyncCopyResource = 32, - // An allocation failure happened that requires uncommitting all the pending - // allocations. Usually this is due to a situation requiring an eviction but - // the eviction couldn't be performed. - kFailRequiresUncommit = 64, - // For prefetching, indicates that all slices have the same start time, in - // which case, we fallback to an unsliced solution. - kAllSlicesHaveTheSameStartTime = 128, - // There were conflicting preferred offsets. - kFailConflictingPreferredOffsets = 256, - // Could not replace the synchronous data movement instruction (e.g., kCopy, - // kSlice) with an asynchronous one - kFailSyncDataMoveReplacement = 512 - }; - // Return true if the result belongs to a failure. - static bool result_is(Result result, Result failure) { + static bool result_is(AllocationResult result, AllocationResult failure) { return static_cast(result) & static_cast(failure); } // Mark (bitwise OR) a failure to the result. - static Result result_mark(Result failure, Result& result) { - result = static_cast(static_cast(result) | - static_cast(failure)); + static AllocationResult result_mark(AllocationResult failure, + AllocationResult& result) { + result = static_cast(static_cast(result) | + static_cast(failure)); return result; } // Return a string representation of a result that has at most a single // failure. Consider using ResultToString for a general case. - static std::string SingleFailureResultToString(const Result& result); + static std::string SingleFailureResultToString( + const AllocationResult& result); // Return a string representation of the result, with possibly more than one // failure. 
- static std::string ResultToString(const Result& result); + static std::string ResultToString(const AllocationResult& result); // Return true if the result is a failure that requires us to uncommit pending // chunks. - static bool result_requires_uncommit(Result result) { - return result_is(result, Result::kFailRequiresUncommit); + static bool result_requires_uncommit(AllocationResult result) { + return result_is(result, AllocationResult::kFailRequiresUncommit); } // Return true if the result is a failure either due to running out of // outstanding asynchronous copies or due to violating asynchronous copy // ordering. - static bool result_failed_because_of_async_copy(Result result) { - return result_is(result, Result::kFailOutOfAsyncCopies) || - result_is(result, Result::kFailViolatesAsyncCopyResource); + static bool result_failed_because_of_async_copy(AllocationResult result) { + return result_is(result, AllocationResult::kFailOutOfAsyncCopies) || + result_is(result, AllocationResult::kFailViolatesAsyncCopyResource); } // For the given loop with the start and end index and loop size, run the @@ -916,7 +650,7 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // All of the allocation values have a must-alias relationship with each // other. Returns either kSuccess if all of the sites could be placed in the // alternate memory or a bitwise OR of failure reasons why they couldn't - absl::StatusOr AllocateAllocationValues( + absl::StatusOr AllocateAllocationValues( absl::Span allocation_values); // Finds an allocation for an allocation request for a segment (see the @@ -937,22 +671,23 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // Result::kSuccess if the buffer could be placed in alternate memory or some // other Result with an OR of reasons why the buffer couldn't be placed in // alternate memory. - Result AllocateSegment(AllocationRequest& request); + AllocationResult AllocateSegment(AllocationRequest& request); // Try allocating in alternate memory without any copies. - Result AllocateInAlternateMemoryNoCopy(const AllocationRequest& request); + AllocationResult AllocateInAlternateMemoryNoCopy( + const AllocationRequest& request); // Try evicting to default memory space. - Result Evict(const AllocationRequest& request); + AllocationResult Evict(const AllocationRequest& request); // Returns the time a copy done of a prefetch should be scheduled. int64_t FindPrefetchEndTime(const AllocationRequest& request, int64_t earliest_prefetch_time) const; // Try prefetching to alternate memory space. - Result Prefetch(const AllocationRequest& request, - Allocation& prev_allocation_in_default_mem, - const Shape* shape = nullptr); + AllocationResult Prefetch(const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem, + const Shape* shape = nullptr); // Helper methods used to implement Prefetch(). // @@ -966,14 +701,16 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { PrefetchContext& context) const; // Initializes the PrefetchIntervalPicker and associated data structures in // context. - Result InitializePrefetchIntervalPicker(PrefetchContext& context); + AllocationResult InitializePrefetchIntervalPicker(PrefetchContext& context); // As a compile time optimization, try a prefetch allocation that is as late // as possible. If this is not able to find a solution, none of the // earlier tries will succeed either. 
- Result EnsureSomeSpatialPrefetchFitExists(PrefetchContext& context) const; + AllocationResult EnsureSomeSpatialPrefetchFitExists( + PrefetchContext& context) const; // Check if for the specified type of solution, using the parameters in // context. If we find a solution, it will be stored in context. - Result CheckPrefetchFit(bool for_sliced_solution, PrefetchContext& context); + AllocationResult CheckPrefetchFit(bool for_sliced_solution, + PrefetchContext& context); // Creates a debugging string describing the timing of the prefetch solution // we are currently attempting (as dictated by for_sliced_solution and // context). @@ -981,8 +718,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { bool for_sliced_solution, const PrefetchContext& context) const; // Try to prefetch a window worth of data into the alternate memory. - Result WindowPrefetch(const AllocationRequest& request, - Allocation& prev_allocation_in_default_mem); + AllocationResult WindowPrefetch(const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem); // Find the best possible chunk candidate, where it has the longest possible // availability if no preferred offset is given, or at the preferred_offset if diff --git a/xla/service/memory_space_assignment/allocation_value.h b/xla/service/memory_space_assignment/allocation_value.h new file mode 100644 index 0000000000000..55cd8c9990736 --- /dev/null +++ b/xla/service/memory_space_assignment/allocation_value.h @@ -0,0 +1,300 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_VALUE_H_ +#define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_VALUE_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_value.h" +#include "xla/service/memory_space_assignment/allocation.h" + +namespace xla { +namespace memory_space_assignment { +// AllocationValue is used to break up HloValues for each non-trivial position +// (trivial positions are considered Tuple, GetTupleElement, and Bitcast). An +// HloValue may include positions and uses that alias with each other across +// multiple computations. We use this class to break these HloValues such that +// every AllocationValue has one defining position (that may alias with other +// AllocationValues). The uses field of the AllocationValue contains only the +// direct uses of the AllocationValue's defining position. +// +// For example, consider the following HLO snippet: +// +// Body { +// body_param = (f32[4,3]{1,0}, f32[]) parameter(0) +// get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element(body_param), +// index=0 +// ... +// ROOT tuple = (f32[4,3]{1,0}, f32[]) tuple(get-tuple-element.3, ...) +// } +// +// Cond { +// cond_param = (f32[4,3]{1,0}, f32[]) parameter(0) +// ... +// } +// +// add.4 = f32[4,3]{1,0} add(...) 
+// tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(add.4, ...) +// while = (f32[4,3]{1,0}, f32[]) while(tuple.1), body=Body, condition=Cond +// get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element(while), index=0 +// add.5 = f32[4,3]{1,0} add(get-tuple-element.5, ...) +// +// This contains an HloValue that looks like the following: +// positions: +// add.4 +// body_param {0} +// get-tuple-element.3 +// tuple {0} +// cond_param {0} +// tuple.1 {0} +// while {0} +// get-tuple-element.5 +// uses: +// add.1, operand 0 +// tuple, operand 0 +// while, operand 0 {0} +// add.5, operand 0 +// +// We break this HloValue up into the following AllocationValues for each +// non-trivial position: +// AllocationValue1: computation = Entry +// position: +// add.4 +// uses: +// while, operand 0 {0} +// AllocationValue2: computation = Cond +// position: +// cond_param {0} +// uses: +// AllocationValue3: computation = Body +// position: +// body_param {0} +// uses: +// add.1, operand 0 +// tuple, operand 0 +// AllocationValue4: computation = Entry +// position: +// while {0} +// uses: +// add.5, operand 0 +class AllocationValue { + public: + // This data structure wraps an HloUse and adds additional metadata that are + // useful for allocation. + struct Use { + // The wrapped HloUse object. + HloUse hlo_use; + // The logical time this use is scheduled. + int64_t time; + // All the positions where this use aliases with. The aliased positions + // must get the same allocation. + std::vector aliases; + // A synchronous memory operation that feeds this use. + // TODO(mehrdadk): extend this to support multiple sync data movement + // operands. + HloInstruction* sync_mem_op_operand = nullptr; + + bool operator==(const Use& other) const { + return hlo_use == other.hlo_use && time == other.time && + aliases == other.aliases; + } + + template + friend H AbslHashValue(H h, const Use& s) { + return H::combine(std::move(h), s.hlo_use, s.time, s.aliases); + } + }; + + AllocationValue(const HloValue* value, const HloPosition& position, + int64_t size) + : value_(value), + defining_position_(position), + size_(size), + requires_contiguous_allocation_(false) {} + + const HloPosition& defining_position() const { return defining_position_; } + const HloInstruction* defining_instruction() const { + return defining_position().instruction; + } + int64_t size() const { return size_; } + const std::vector& uses() const { return uses_; } + std::vector& uses() { return uses_; } + const HloValue* value() const { return value_; } + const HloComputation* computation() const { + return defining_instruction()->parent(); + } + AllocationSequence* mutable_allocation_sequence() { + return &allocation_sequence_; + } + const AllocationSequence* allocation_sequence() const { + return &allocation_sequence_; + } + + // Sets/gets whether this AllocationValue requires allocating it + // contiguously throughout its live range (without any copies). + bool requires_contiguous_allocation() const { + return requires_contiguous_allocation_; + } + void set_requires_contiguous_allocation(bool requires_contiguous_allocation) { + requires_contiguous_allocation_ = requires_contiguous_allocation; + } + + void AddUse(const HloUse& use, int64_t use_time) { + uses_.push_back({use, use_time, {}}); + } + + std::string ToString() const; + std::string ToShortString() const; + + private: + const HloValue* value_; + HloPosition defining_position_; + int64_t size_; + // If true, there must be a contiguous allocation for this buffer without + // any copies. 
+ bool requires_contiguous_allocation_; + std::vector uses_; + AllocationSequence allocation_sequence_; +}; + +// A data structure we use to associate Allocation objects that are aliased +// and must get the same offset. +struct AliasedOffset { + int64_t offset; + absl::flat_hash_set allocations; +}; + +// An allocation request for a use segment. A use segment is the time segment +// between the definition and the first use, and the time segment between the +// uses of a buffer. For example, the time between the definition and Use1, is +// the first segment, and the time between Use1 and Use2 is the second segment +// and so on: +// +// +------+----------+-------+ +// / \ \ \ +// / v v v +// Def Use1 Use2 Use3 +// <----------> <--------> <-----> +// Segment Segment Segment +// +// start_time and end_time are the start and end logical times of the segment. +// use_times is a sorted sequence of the times of all uses. +// latest_prefetch_time is the latest time we can schedule the CopyDone for a +// prefetch. +// If allow_no_copy_alternate_mem_allocation is false, an eviction is forced. +// If earliest_prefetch_time is set, prefetches cannot start before this +// value. +// +// In case we are trying to replace synchronous copies, and for example Use2 +// is a replaceable sync copy candidate, we now skip Use2 and segments will be +// between Def, Use1, Use2.1, Use2.2, Use3: +// +------+----------+-------------------+ +// / \ \ \ + // / v v v +// Def Use1 Use2(Sync Copy) Use3 +// | | \ \ | +// | | v v | +// | | Use2.1 Use2.2 | +// |<---------->|<---------->|<------->|<----->| +// | Segment | Segment | Segment |Segment| + +struct AllocationRequest { + int64_t inclusive_start_time; + int64_t end_time; + int64_t latest_prefetch_time; + // See the comment for require_copy_allocation + int64_t required_copy_allocation_latest_time; + int64_t size; + bool prefer_no_copy_alternate_mem_allocation; + bool allow_no_copy_alternate_mem_allocation; + bool require_no_copy_alternate_mem_allocation; + // If true, indicates we are requiring a copy allocation between def and + // use, that finishes by required_copy_allocation_latest_time. + // required_copy_allocation_for is a synchronous copy instruction that will + // be removed, if we are successful in adding the copy allocation. + bool require_copy_allocation; + bool allow_prefetch; + std::optional earliest_prefetch_time; + std::optional preferred_prefetch_time; + AliasedOffset* preferred_offset; + const AllocationValue::Use* use; + AllocationValue* allocation_value; + absl::Span all_use_times; + // See the comment for require_copy_allocation + HloInstruction* required_copy_allocation_for; + // If the required copy in require_copy_allocation is only for a slice of + // the allocation_value + bool required_copy_for_slice; + // The resulting Allocation will be added to the AllocationSequence of + // allocation_value_to_update. We only expect allocation_value_to_update to + // be different from allocation_value in the case of a synchronous memory + // operation conversion to asynchronous, otherwise, they should be the same. + AllocationValue* allocation_value_to_update; + // No new Allocation is needed to be created and we will only extend an + // existing one. + bool only_extend_existing_allocation; + // Data structure that contains the options for making window prefetched + // allocations. + const WindowPrefetchedAllocation::Options* window_prefetch_options = nullptr; +}; + +// Result of an allocation, prefetch, eviction etc. request. 
+// either kSuccess or a bitwise OR of one or more failures. The values are
+// unique powers of two. To check if a result contains a particular failure,
+// use the result_is method. To add a new failure to a result, use the
+// result_mark method.
+enum class AllocationResult {
+  // Successful allocation.
+  kSuccess = 0,
+  // Allocation failed because we ran out of alternate memory.
+  kFailOutOfMemory = 1,
+  // A no-copy allocation couldn't be performed because the previous
+  // allocation wasn't in the alternate memory space.
+  kFailPrevAllocationNotInAlternateMem = 2,
+  // A no-copy allocation couldn't be performed because the live range was
+  // too long.
+  kFailLiveRangeTooLong = 4,
+  // A prefetch couldn't be performed because the live range was too short.
+  kFailLiveRangeTooShort = 8,
+  // Hit the limit on outstanding asynchronous copies, either during
+  // prefetching or eviction.
+  kFailOutOfAsyncCopies = 16,
+  // A prefetch couldn't be performed because the asynchronous copy resource
+  // was violated.
+  kFailViolatesAsyncCopyResource = 32,
+  // An allocation failure happened that requires uncommitting all the
+  // pending allocations. Usually this is due to a situation requiring an
+  // eviction but the eviction couldn't be performed.
+  kFailRequiresUncommit = 64,
+  // For prefetching, indicates that all slices have the same start time, in
+  // which case we fall back to an unsliced solution.
+  kAllSlicesHaveTheSameStartTime = 128,
+  // There were conflicting preferred offsets.
+  kFailConflictingPreferredOffsets = 256,
+  // Could not replace the synchronous data movement instruction (e.g., kCopy,
+  // kSlice) with an asynchronous one.
+  kFailSyncDataMoveReplacement = 512
+};
+
+}  // namespace memory_space_assignment
+}  // namespace xla
+#endif  // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_VALUE_H_
diff --git a/xla/service/memory_space_assignment/memory_space_assignment.cc b/xla/service/memory_space_assignment/memory_space_assignment.cc
index 8902ac165f2d6..5216d08860d66 100644
--- a/xla/service/memory_space_assignment/memory_space_assignment.cc
+++ b/xla/service/memory_space_assignment/memory_space_assignment.cc
@@ -347,6 +347,37 @@ MemorySpaceAssignment::Run(HloModule* module,
                            alias_analysis);
 }
 
+absl::Status MemorySpaceAssignment::VerifyAllocations() const {
+  BufferIntervalTree interval_tree;
+  // Checks that the chunks overlapping a given allocation in time do not
+  // overlap with the allocation's chunk in the memory range. If they do, the
+  // CHECK fails; otherwise we add the allocation's chunk to the interval
+  // tree and return an OK status.
+  auto add_allocation_and_verify =
+      [&](const Allocation* allocation) -> absl::Status {
+    for (const HeapSimulator::Chunk& overlapping_chunk :
+         interval_tree.ChunksOverlappingInTime(allocation->start_time(),
+                                               allocation->end_time() - 1)) {
+      CHECK(!allocation->chunk().OverlapsWith(overlapping_chunk))
+          << "Chunks are overlapping at Allocation level (before fixing the "
+             "schedule): "
+          << allocation->ToString()
+          << " overlaps with allocated chunk: " << overlapping_chunk.ToString();
+    }
+    interval_tree.Add(allocation->start_time(), allocation->end_time() - 1,
+                      allocation->chunk());
+    return absl::OkStatus();
+  };
+  // Verify that all alternate memory allocations are free of overlapping
+  // Allocations in time and space, and add them to interval_tree one by one.
+  for (const auto& allocation : allocations_) {
+    if (allocation->memory_space() == MemorySpace::kAlternate) {
+      TF_RETURN_IF_ERROR(add_allocation_and_verify(allocation.get()));
+    }
+  }
+  return absl::OkStatus();
+}
+
 absl::StatusOr<std::unique_ptr<PresetAssignments>>
 MemorySpaceAssignment::RunMemorySpaceAssignment(
     const HloLiveRange& hlo_live_range,
@@ -365,6 +396,9 @@ MemorySpaceAssignment::RunMemorySpaceAssignment(
   }
   TF_RETURN_IF_ERROR(Process(hlo_live_range));
+  if (options_.verify) {
+    TF_RETURN_IF_ERROR(VerifyAllocations());
+  }
   // DEBUG_LOG_ALLOCATIONS_AT
   //
   // Uncomment the following to log the alternate memory allocations that MSA
diff --git a/xla/service/memory_space_assignment/memory_space_assignment.h b/xla/service/memory_space_assignment/memory_space_assignment.h
index 9903c1b32a4a4..e2ff35441e4d5 100644
--- a/xla/service/memory_space_assignment/memory_space_assignment.h
+++ b/xla/service/memory_space_assignment/memory_space_assignment.h
@@ -307,6 +307,11 @@ class MemorySpaceAssignment {
   // Calculates asynchronous copy statistics.
   absl::StatusOr<AsyncCopyStats> CalculateAsyncCopyStats() const;
+  // Verify that allocations_ are free of overlapping Allocations in time and
+  // space. This is a post-processing step called after all allocations have
+  // been finalized, before the async copies get scheduled.
+  absl::Status VerifyAllocations() const;
+
   // Verify that the memory space assignment is free of overlapping buffers and
   // export heap simulator trace to be used by buffer_assignment.
   //
diff --git a/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/xla/service/memory_space_assignment/memory_space_assignment_test.cc
index 4b3f03f7bffad..647e52baf552d 100644
--- a/xla/service/memory_space_assignment/memory_space_assignment_test.cc
+++ b/xla/service/memory_space_assignment/memory_space_assignment_test.cc
@@ -18,11 +18,12 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
-#include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -65,6 +66,7 @@ limitations under the License.
 #include "xla/service/hlo_value.h"
 #include "xla/service/memory_space_assignment/algorithm.h"
 #include "xla/service/memory_space_assignment/allocation.h"
+#include "xla/service/memory_space_assignment/allocation_value.h"
 #include "xla/service/memory_space_assignment/buffer_interval_comparator.h"
 #include "xla/service/memory_space_assignment/cost_analysis.h"
 #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h"
@@ -77,7 +79,6 @@ limitations under the License.
#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -161,7 +162,7 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { Options options; options.max_size_in_bytes = 128; options.alignment_in_bytes = 8; - options.verify = true; + options.verify = false; options.alternate_memory_space = kAlternateMemorySpace; options.max_outstanding_prefetches = -1; options.max_outstanding_evictions = -1; @@ -1215,6 +1216,168 @@ TEST_F(MemorySpaceAssignmentTest, ConditionalCopyReplacement) { op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); } +TEST_F(MemorySpaceAssignmentTest, AllocationRequestAndResultModifierTest) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + negate0 = f32[2,3]{1,0} negate(p1) + negate1 = f32[2,3]{1,0} negate(negate0) + negate2 = f32[2,3]{1,0} negate(negate1) + negate3 = f32[2,3]{1,0} negate(negate2) + negate4 = f32[2,3]{1,0} negate(negate3) + negate5 = f32[2,3]{1,0} negate(negate4) + negate6 = f32[2,3]{1,0} negate(negate5) + negate7 = f32[2,3]{1,0} negate(negate6) + ROOT add0 = f32[2,3]{1,0} add(p0, negate7) + } + )"; + // The baseline behavior is to prefetch p0 at add0. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr baseline_module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + AssignMemorySpace(baseline_module.get(), options); + HloInstruction* add0 = FindInstruction(baseline_module.get(), "add0"); + ASSERT_NE(add0, nullptr); + HloInstruction* p0 = FindInstruction(baseline_module.get(), "p0"); + ASSERT_NE(p0, nullptr); + EXPECT_THAT(add0->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + + // We should be able to prevent prefetching p0 at add0 using + // allocation_result_modifier_testing_fn. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr result_modifier_module, + ParseAndReturnVerifiedModule(hlo_string)); + options.max_retries = 1; + options.allocation_request_modifier_testing_fn = nullptr; + options.allocation_result_modifier_testing_fn = + [](const AllocationRequest& request, AllocationResult& result) { + if (request.allocation_value_to_update->defining_instruction() + ->name() == "p0" && + request.use->hlo_use.instruction->name() == "add0") { + result = AllocationResult::kFailRequiresUncommit; + } + }; + AssignMemorySpace(result_modifier_module.get(), options); + add0 = FindInstruction(result_modifier_module.get(), "add0"); + ASSERT_NE(add0, nullptr); + p0 = FindInstruction(result_modifier_module.get(), "p0"); + ASSERT_NE(p0, nullptr); + EXPECT_EQ(add0->operand(0), p0); + + // We should be able to enforce an earlier prefetch of p0 at add0 using + // allocation_request_modifier_testing_fn. + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr request_modifier_module, + ParseAndReturnVerifiedModule(hlo_string)); + options.max_retries = 1; + options + .allocation_request_modifier_testing_fn = [](AllocationRequest& request) { + if (request.allocation_value_to_update->defining_instruction()->name() == + "p0" && + request.use->hlo_use.instruction->name() == "add0") { + // Schedule the copy-done before negate4 (scheduled at 6). 
+      request.latest_prefetch_time = 6;
+    }
+  };
+  options.allocation_result_modifier_testing_fn = nullptr;
+  AssignMemorySpace(request_modifier_module.get(), options);
+  add0 = FindInstruction(request_modifier_module.get(), "add0");
+  CHECK_NE(add0, nullptr);
+  p0 = FindInstruction(request_modifier_module.get(), "p0");
+  CHECK_NE(p0, nullptr);
+  EXPECT_THAT(add0->operand(0),
+              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0));
+  // The copy-done should have been scheduled before negate4.
+  HloInstruction* negate4 =
+      FindInstruction(request_modifier_module.get(), "negate4");
+  CHECK_NE(negate4, nullptr);
+  const HloInstructionSequence& sequence =
+      request_modifier_module->schedule().sequence(
+          request_modifier_module->entry_computation());
+  auto find_index = [&](const HloInstruction* instruction) {
+    return std::distance(sequence.instructions().begin(),
+                         std::find(sequence.instructions().begin(),
+                                   sequence.instructions().end(), instruction));
+  };
+
+  int negate4_index = find_index(negate4);
+  int copy_done_index = find_index(add0->operand(0));
+  EXPECT_LT(copy_done_index, negate4_index);
+}
+
+// Added for b/376869021, which surfaced when we tried to convert a sync slice
+// that had to extend the allocation of its operand in the alternate memory. In
+// this test, we expect the slice0 operand (p0_copy) to maintain a valid
+// allocation in the alternate memory until it gets transferred by the async
+// replacement of slice0. We stress-test that validity by delaying the
+// allocation of slice0 by 3 steps.
+TEST_F(MemorySpaceAssignmentTest, SyncReplacementAllocationExtensionBug) {
+  absl::string_view hlo_string = R"(
+HloModule module, is_scheduled=true
+
+ENTRY entry {
+  p0 = f32[2,2,3]{2,1,0} parameter(0)
+  p1 = f32[4,2,3]{2,1,0} parameter(1)
+  p0_copy = f32[2,2,3]{2,1,0} copy(p0)
+  negate0 = negate(p1)
+  negate1 = negate(negate0)
+  negate2 = negate(negate1)
+  p0_copy0_negate = negate(p0_copy)
+  copy_negate2 = copy(negate2)
+  slice0 = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]}
+  negate3 = negate(copy_negate2)
+  negate4 = negate(negate3)
+  negate5 = negate(negate4)
+  negate6 = negate(negate5)
+  negate7 = negate(negate6)
+  neg_slice0 = negate(slice0)
+  ROOT tuple = tuple(negate7, neg_slice0)
+})";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  Options options = DefaultMemorySpaceOptions();
+  options.enable_sync_copy_replacement = false;
+  options.enable_sync_slice_replacement = true;
+  options.verify = true;
+  options.is_async_slice_implemented_fn =
+      [](const HloInstruction* instruction) { return true; };
+  options.max_size_in_bytes = 96;
+  options.is_position_allowed_in_alternate_mem_fn =
+      [](const HloPosition& position) {
+        return position.instruction->name() != "p0_copy";
+      };
+  // Delay the allocation of slice0 by 3 steps to allow copy_negate2 to be
+  // allocated in alternate memory.
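+  // Only requests that extend an existing allocation (those with
+  // only_extend_existing_allocation set) are shifted below.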
+  options.allocation_request_modifier_testing_fn =
+      [](AllocationRequest& request) {
+        if (request.only_extend_existing_allocation) {
+          request.inclusive_start_time += 3;
+          request.end_time += 3;
+        }
+      };
+  const std::string text_proto = R"pb(
+    overrides {
+      hlo_position_matcher { instruction_name_regex: "copy_negate2|p0_copy" }
+      override_options { assign_first: true }
+    })pb";
+  TF_ASSERT_OK_AND_ASSIGN(auto msa_sort_order_overrides,
+                          ParseTextProto<MsaSortOrderOverrides>(text_proto));
+  auto preset_assignments = AssignMemorySpaceUsingCostAnalysis(
+      module.get(), options,
+      /*cost_analysis_options_override=*/std::nullopt,
+      /*hlo_cost_options_override=*/std::nullopt,
+      /*optional_msa_sort_order_overrides=*/msa_sort_order_overrides);
+  HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy");
+  ASSERT_NE(p0_copy, nullptr);
+  HloInstruction* neg_slice0 = FindInstruction(module.get(), "neg_slice0");
+  ASSERT_NE(neg_slice0, nullptr);
+  EXPECT_THAT(neg_slice0->operand(0), op::AsyncDone(op::AsyncStart(p0_copy)));
+}
+
 TEST_F(MemorySpaceAssignmentTest, AlwaysSpillJitPrefetchTest) {
   // The negate chain is long enough for asynchronous copy to be inserted
   // between p1 and add.
diff --git a/xla/service/memory_space_assignment/options.h b/xla/service/memory_space_assignment/options.h
index 48799ae046cea..075b6016fc9cd 100644
--- a/xla/service/memory_space_assignment/options.h
+++ b/xla/service/memory_space_assignment/options.h
@@ -30,8 +30,8 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/service/buffer_value.h"
-#include "xla/service/heap_simulator/heap_simulator.h"
 #include "xla/service/hlo_value.h"
+#include "xla/service/memory_space_assignment/allocation_value.h"
 #include "xla/service/memory_space_assignment/buffer_interval_comparator.h"
 #include "xla/service/memory_space_assignment/cost_analysis.h"
 #include "xla/service/memory_space_assignment/memory_space_assignment.pb.h"
@@ -128,9 +128,23 @@ struct Options {
   WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn =
       [](HloInstruction*, int64_t, int64_t) {};
 
+  // This function can be used to check if an equivalent asynchronous slice
+  // lowering is implemented for a given synchronous slice instruction.
   IsAsyncSliceImplementedFunction is_async_slice_implemented_fn =
       [](const HloInstruction*) { return false; };
 
+  // Should only be used for testing purposes. This function allows us to
+  // modify the AllocationResult after the AllocationRequest has been processed
+  // by AllocateSegment().
+  std::function<void(const AllocationRequest&, AllocationResult&)>
+      allocation_result_modifier_testing_fn = nullptr;
+
+  // Should only be used for testing purposes. This function allows us to
+  // modify the AllocationRequest before it is passed to AllocateSegment().
+  std::function<void(AllocationRequest&)>
+      allocation_request_modifier_testing_fn = nullptr;
+
   // If true, we will try to reduce scoped allocation buffer size for all
   // instructions if their operand/output has been allocated in alternate
   // memory.