Skip to content

Commit

Permalink
Add device remove handler to the KSA transport
Browse files Browse the repository at this point in the history
This should fix #403 for cases when a device is updated to the new MIDI 2.0 driver after it has already been picked up by the KSA transport.
  • Loading branch information
Psychlist1972 committed Oct 7, 2024
1 parent 46c49ad commit ff8ef86
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 51 deletions.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified build/staging/app-sdk/Arm64EC/mididiag.exe
Binary file not shown.
2 changes: 1 addition & 1 deletion build/staging/version/BundleInfo.wxi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Include>
<?define SetupVersionName="Developer Preview 7 Arm64" ?>
<?define SetupVersionNumber="1.0.24260.2222" ?>
<?define SetupVersionNumber="1.0.24281.1756" ?>
</Include>
Original file line number Diff line number Diff line change
Expand Up @@ -951,8 +951,8 @@ CMidi2KSMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ CMidi2KSAggregateMidiEndpointManager::Initialize(
_Use_decl_annotations_
HRESULT
CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KsAggregateEndpointDefinition& MasterEndpointDefinition
KsAggregateEndpointDefinition& masterEndpointDefinition
)
{
TraceLoggingWrite(
Expand All @@ -96,21 +96,21 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Creating aggregate UMP endpoint", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name")
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name")
);


// we require at least one valid pin
RETURN_HR_IF(E_INVALIDARG, MasterEndpointDefinition.Pins.size() < 1);
RETURN_HR_IF(E_INVALIDARG, masterEndpointDefinition.Pins.size() < 1);

std::vector<DEVPROPERTY> interfaceDevProperties;

MIDIENDPOINTCOMMONPROPERTIES commonProperties{};
commonProperties.AbstractionLayerGuid = ABSTRACTION_LAYER_GUID;
commonProperties.EndpointPurpose = MidiEndpointDevicePurposePropertyValue::NormalMessageEndpoint;
commonProperties.FriendlyName = MasterEndpointDefinition.EndpointName.c_str();
commonProperties.FriendlyName = masterEndpointDefinition.EndpointName.c_str();
commonProperties.TransportCode = TRANSPORT_CODE;
commonProperties.TransportSuppliedEndpointName = MasterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointName = masterEndpointDefinition.FilterName.c_str();
commonProperties.TransportSuppliedEndpointDescription = nullptr;
commonProperties.UserSuppliedEndpointName = nullptr;
commonProperties.UserSuppliedEndpointDescription = nullptr;
Expand All @@ -126,10 +126,10 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
commonProperties.SupportsMidi2ProtocolDefaultValue = false;

interfaceDevProperties.push_back({ {DEVPKEY_KsMidiPort_KsFilterInterfaceId, DEVPROP_STORE_SYSTEM, nullptr},
DEVPROP_TYPE_STRING, static_cast<ULONG>((MasterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)MasterEndpointDefinition.FilterDeviceId.c_str() });
DEVPROP_TYPE_STRING, static_cast<ULONG>((masterEndpointDefinition.FilterDeviceId.length() + 1) * sizeof(WCHAR)), (PVOID)masterEndpointDefinition.FilterDeviceId.c_str() });

interfaceDevProperties.push_back({ {DEVPKEY_KsTransport, DEVPROP_STORE_SYSTEM, nullptr },
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&MasterEndpointDefinition.TransportCapability });
DEVPROP_TYPE_UINT32, static_cast<ULONG>(sizeof(UINT32)), (PVOID)&masterEndpointDefinition.TransportCapability });

// create group terminal blocks and the pin map

Expand All @@ -140,7 +140,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
KSMIDI_PIN_MAP pinMap{ };
std::vector<internal::GroupTerminalBlockInternal> blocks;

for (auto const& pin : MasterEndpointDefinition.Pins)
for (auto const& pin : masterEndpointDefinition.Pins)
{
internal::GroupTerminalBlockInternal gtb;

Expand Down Expand Up @@ -218,9 +218,9 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
SW_DEVICE_CREATE_INFO createInfo{ };

createInfo.cbSize = sizeof(createInfo);
createInfo.pszInstanceId = MasterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.pszInstanceId = masterEndpointDefinition.EndpointDeviceInstanceId.c_str();
createInfo.CapabilityFlags = SWDeviceCapabilitiesNone;
createInfo.pszDeviceDescription = MasterEndpointDefinition.EndpointName.c_str();
createInfo.pszDeviceDescription = masterEndpointDefinition.EndpointName.c_str();

// Call the device manager and finish the creation

Expand All @@ -231,7 +231,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(

LOG_IF_FAILED(
swdCreationResult = m_MidiDeviceManager->ActivateEndpoint(
MasterEndpointDefinition.ParentDeviceInstanceId.c_str(),
masterEndpointDefinition.ParentDeviceInstanceId.c_str(),
false, // TODO: create UMP only, handle the MIDI 1.0 compat ports
MidiFlow::MidiFlowBidirectional, // bidi only for the UMP endpoint
&commonProperties,
Expand All @@ -254,13 +254,18 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint created", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(newDeviceInterfaceId, "endpoint id")
);

// TODO: Add to internal endpoint manager, and also return the device interface id
// todo: return new device interface id

return swdCreationResult; // TODO change this to account for other steps


// Add to internal endpoint manager
m_availableEndpointDefinitions.push_back(std::move(masterEndpointDefinition));

return swdCreationResult;
}
else
{
Expand All @@ -271,7 +276,7 @@ CMidi2KSAggregateMidiEndpointManager::CreateMidiUmpEndpoint(
TraceLoggingLevel(WINEVENT_LEVEL_ERROR),
TraceLoggingPointer(this, "this"),
TraceLoggingWideString(L"Aggregate UMP endpoint creation failed", MIDI_TRACE_EVENT_MESSAGE_FIELD),
TraceLoggingWideString(MasterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingWideString(masterEndpointDefinition.EndpointName.c_str(), "name"),
TraceLoggingHResult(swdCreationResult, MIDI_TRACE_EVENT_HRESULT_FIELD)
);

Expand Down Expand Up @@ -588,39 +593,34 @@ HRESULT CMidi2KSAggregateMidiEndpointManager::OnDeviceRemoved(DeviceWatcher, Dev
);


// TODO
// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
do
{
auto item = std::find_if(m_availableEndpointDefinitions.begin(), m_availableEndpointDefinitions.end(), [&](const KsAggregateEndpointDefinition& endpointDefinition)
{
if (device.Id() == endpointDefinition.ParentDeviceInstanceId)
{
return true;
}

return false;
});


if (item == m_availableEndpointDefinitions.end())
{
break;
}

// the interface is no longer active, search through our m_AvailableMidiPins to identify
// every entry with this filter interface id, and remove the SWD and remove the pin(s) from
// the m_AvailableMidiPins list.
//do
//{
// auto item = std::find_if(m_AvailableMidiPins.begin(), m_AvailableMidiPins.end(), [&](const std::unique_ptr<MIDI_PIN_INFO>& Pin)
// {
// // if this interface id already activated, then we cannot activate/create a second time,
// if (device.Id() == Pin->Id)
// {
// return true;
// }

// return false;
// });

// if (item == m_AvailableMidiPins.end())
// {
// break;
// }

// // notify the device manager using the InstanceId for this midi device
// LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(item->get()->InstanceId.c_str()));

// // remove the MIDI_PIN_INFO from the list
// m_AvailableMidiPins.erase(item);
//}
//while (TRUE);
// notify the device manager using the InstanceId for this midi device
LOG_IF_FAILED(m_MidiDeviceManager->RemoveEndpoint(device.Id().c_str()));

// remove the MIDI_PIN_INFO from the list
m_availableEndpointDefinitions.erase(item);
}
while (TRUE);

return S_OK;
}
Expand Down Expand Up @@ -669,8 +669,8 @@ CMidi2KSAggregateMidiEndpointManager::Cleanup()
m_DeviceStopped.revoke();
m_DeviceEnumerationCompleted.revoke();

m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();
m_MidiDeviceManager.reset();
m_MidiProtocolManager.reset();

return S_OK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ struct KsAggregateEndpointDefinition

MidiTransport TransportCapability;


std::vector<KsAggregateEndpointPinDefinition> Pins;
};

Expand Down Expand Up @@ -59,7 +58,7 @@ class CMidi2KSAggregateMidiEndpointManager :
wil::com_ptr_nothrow<IMidiDeviceManagerInterface> m_MidiDeviceManager;
wil::com_ptr_nothrow<IMidiEndpointProtocolManagerInterface> m_MidiProtocolManager;

//std::vector<std::unique_ptr<MIDI_PIN_INFO>> m_AvailableMidiPins;
std::vector<KsAggregateEndpointDefinition> m_availableEndpointDefinitions;

DeviceWatcher m_Watcher{0};
winrt::impl::consume_Windows_Devices_Enumeration_IDeviceWatcher<IDeviceWatcher>::Added_revoker m_DeviceAdded;
Expand Down

0 comments on commit ff8ef86

Please sign in to comment.