Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Allow specifying the HTTP protocol version and version policy #2809

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Net;
using System.Net.Http;
Expand Down Expand Up @@ -75,6 +76,18 @@ public HttpDocumentRetriever(HttpClient httpClient)
/// </summary>
public bool RequireHttps { get; set; } = true;

/// <summary>
/// If set specifies the protocol version to use when sending HTTP requests.
/// </summary>
public Version HttpVersion { get; set; }

#if NET6_0_OR_GREATER
/// <summary>
/// If set specifies the protocol version policy to use when sending HTTP requests.
/// </summary>
public HttpVersionPolicy? HttpVersionPolicy { get; set; }
#endif

/// <summary>
/// Returns a task which contains a string converted from remote document when completed, by using the provided address.
/// </summary>
Expand Down Expand Up @@ -132,6 +145,32 @@ public async Task<string> GetDocumentAsync(string address, CancellationToken can
throw LogHelper.LogExceptionMessage(unsuccessfulHttpResponseException);
}

/// <summary>
/// Applies the HTTP version and version policy to the <see cref="HttpRequestMessage"/>.
/// </summary>
/// <param name="httpClient">The <see cref="HttpClient"/> used to obtain the default values.</param>
/// <param name="message">The <see cref="HttpRequestMessage"/> where to apply the version and policy.</param>
[SuppressMessage("Usage", "CA1801:Review unused parameters", Justification = "Parameter is only used for .NET 6")]
private void ApplyHttpVersionAndPolicy(HttpClient httpClient, HttpRequestMessage message)
{
// either use explicit or default version from HttpClient
if (HttpVersion != null)
{
message.Version = HttpVersion;
}
else
{
#if NET6_0_OR_GREATER
message.Version = httpClient.DefaultRequestVersion;
#endif
}

#if NET6_0_OR_GREATER
// either use explicit or default version policy from HttpClient
message.VersionPolicy = HttpVersionPolicy.GetValueOrDefault(httpClient.DefaultVersionPolicy);
#endif
}

private async Task<HttpResponseMessage> SendAndRetryOnNetworkErrorAsync(HttpClient httpClient, Uri uri)
{
int maxAttempt = 2;
Expand All @@ -141,6 +180,8 @@ private async Task<HttpResponseMessage> SendAndRetryOnNetworkErrorAsync(HttpClie
// need to create a new message each time since you cannot send the same message twice
using (var message = new HttpRequestMessage(HttpMethod.Get, uri))
{
ApplyHttpVersionAndPolicy(httpClient, message);

if (SendAdditionalHeaderData)
IdentityModelTelemetryUtil.SetTelemetryData(message, AdditionalHeaderData);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -16,7 +17,7 @@
namespace Microsoft.IdentityModel.Protocols.Tests
{
/// <summary>
///
///
/// </summary>
public class HttpDocumentRetrieverTests
{
Expand Down Expand Up @@ -185,6 +186,92 @@ public static TheoryData<DocumentRetrieverTheoryData> GetMetadataTheoryData
return theoryData;
}
}

[Theory, MemberData(nameof(GetVersionTheoryData))]
public async Task HttpVersionTest(Version version)
{
var callback = new Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>>((msg, ct) =>
{
Assert.Equal(version, msg.Version);
return Task.FromResult(new HttpResponseMessage());
});

using var httpClient = new HttpClient(new DelegateHttpMessageHandler(callback));
var documentRetriever = new HttpDocumentRetriever(httpClient) { HttpVersion = version };
await documentRetriever.GetDocumentAsync("https://localhost", CancellationToken.None);
}

public static TheoryData<Version> GetVersionTheoryData
{
get
{
var theoryData = new TheoryData<Version>();
theoryData.Add(new Version(1,0));
theoryData.Add(new Version(1,1));
theoryData.Add(new Version(2,0));
return theoryData;
}
}

#if NET6_0_OR_GREATER
[Theory, MemberData(nameof(GetVersionTheoryData))]
public async Task HttpDefaultRequestVersionTest(Version version)
{
var callback = new Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>>((msg, ct) =>
{
Assert.Equal(version, msg.Version);
return Task.FromResult(new HttpResponseMessage());
});

using var httpClient = new HttpClient(new DelegateHttpMessageHandler(callback));
httpClient.DefaultRequestVersion = version;

var documentRetriever = new HttpDocumentRetriever(httpClient);
await documentRetriever.GetDocumentAsync("https://localhost", CancellationToken.None);
}

[Theory, MemberData(nameof(GetVersionPolicyTheoryData))]
public async Task HttpDefaultVersionPolicyTest(HttpVersionPolicy policy)
{
var callback = new Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>>((msg, ct) =>
{
Assert.Equal(policy, msg.VersionPolicy);
return Task.FromResult(new HttpResponseMessage());
});

using var httpClient = new HttpClient(new DelegateHttpMessageHandler(callback));
httpClient.DefaultVersionPolicy = policy;

var documentRetriever = new HttpDocumentRetriever(httpClient);
await documentRetriever.GetDocumentAsync("https://localhost", CancellationToken.None);
}

[Theory, MemberData(nameof(GetVersionPolicyTheoryData))]
public async Task HttpVersionPolicyTest(HttpVersionPolicy policy)
{
var callback = new Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>>((msg, ct) =>
{
Assert.Equal(policy, msg.VersionPolicy);
return Task.FromResult(new HttpResponseMessage());
});

using var httpClient = new HttpClient(new DelegateHttpMessageHandler(callback));
var documentRetriever = new HttpDocumentRetriever(httpClient) { HttpVersionPolicy = policy };
await documentRetriever.GetDocumentAsync("https://localhost", CancellationToken.None);
}

public static TheoryData<HttpVersionPolicy> GetVersionPolicyTheoryData
{
get
{
var theoryData = new TheoryData<HttpVersionPolicy>();
theoryData.Add(HttpVersionPolicy.RequestVersionOrLower);
theoryData.Add(HttpVersionPolicy.RequestVersionOrHigher);
theoryData.Add(HttpVersionPolicy.RequestVersionExact);
return theoryData;
}
}
#endif
}

public class DocumentRetrieverTheoryData : TheoryDataBase
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.IdentityModel.TestUtils
{
/// <summary>
/// A <see cref="HttpMessageHandler"/> which delegates sending the request to a callback.
/// </summary>
public class DelegateHttpMessageHandler : HttpMessageHandler
{
private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> _callback;

/// <summary>
/// Initializes a new instance of the <see cref="DelegateHttpMessageHandler"/>.
/// </summary>
/// <param name="callback">The callback to invoke when HTTP request is being executed.</param>
public DelegateHttpMessageHandler(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> callback)
{
_callback = callback;
}

/// <inheritdoc />
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
return await _callback(request, cancellationToken).ConfigureAwait(false);
}
}
}