Skip to content

Commit

Permalink
Use URLSession.dataTask() with delegate instead of URLSession.bytes() (
Browse files Browse the repository at this point in the history
…#282)

* Use URLSession.dataTask() with delegate instead of URLSession.bytes()

* Use URLSession.shared instead of creating a new one each time
  • Loading branch information
edigaryev authored Oct 18, 2022
1 parent af7530e commit 39e1b84
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 19 deletions.
46 changes: 46 additions & 0 deletions Sources/tart/Fetcher.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import Foundation
import AsyncAlgorithms

class Fetcher: NSObject, URLSessionTaskDelegate, URLSessionDelegate, URLSessionDataDelegate {
let responseCh = AsyncThrowingChannel<URLResponse, Error>()
let dataCh = AsyncThrowingChannel<Data, Error>()

func fetch(_ request: URLRequest) async throws -> (AsyncThrowingChannel<Data, Error>, URLResponse) {
let task = URLSession.shared.dataTask(with: request)
task.delegate = self
task.resume()

// Wait for the response and only then return
var iter = responseCh.makeAsyncIterator()
let response = try await iter.next()!

return (dataCh, response)
}

func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive response: URLResponse) async -> URLSession.ResponseDisposition {
await responseCh.send(response)

return .allow
}

func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
let sema = DispatchSemaphore(value: 0)

Task {
await dataCh.send(data)
sema.signal()
}

sema.wait()
}

func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
if let error = error {
// Premature termination
responseCh.fail(error)
dataCh.fail(error)
} else {
dataCh.finish()
}
}
}
32 changes: 15 additions & 17 deletions Sources/tart/OCI/Registry.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ import Foundation
import Algorithms
import AsyncAlgorithms

let chunkSizeBytes = 1 * 1024 * 1024

enum RegistryError: Error {
case UnexpectedHTTPStatusCode(when: String, code: Int, details: String = "")
case MissingLocationHeader
Expand Down Expand Up @@ -31,11 +29,11 @@ extension Data {
}
}

extension URLSession.AsyncBytes {
extension AsyncThrowingChannel<Data, Error> {
func asData() async throws -> Data {
var result = Data()

for try await chunk in chunks(ofCount: chunkSizeBytes) {
for try await chunk in self {
result += chunk
}

Expand Down Expand Up @@ -228,14 +226,14 @@ class Registry {
}

public func pullBlob(_ digest: String, handler: (Data) throws -> Void) async throws {
let (bytes, response) = try await bytesRequest(.GET, endpointURL("\(namespace)/blobs/\(digest)"))
let (channel, response) = try await channelRequest(.GET, endpointURL("\(namespace)/blobs/\(digest)"))
if response.statusCode != HTTPCode.Ok.rawValue {
let body = try await bytes.asData().asText()
let body = try await channel.asData().asText()
throw RegistryError.UnexpectedHTTPStatusCode(when: "pulling blob", code: response.statusCode,
details: body)
}

for try await part in bytes.chunks(ofCount: chunkSizeBytes) {
for try await part in channel {
try Task.checkCancellation()

try handler(Data(part))
Expand All @@ -256,20 +254,20 @@ class Registry {
body: Data? = nil,
doAuth: Bool = true
) async throws -> (Data, HTTPURLResponse) {
let (bytes, response) = try await bytesRequest(method, urlComponents,
let (channel, response) = try await channelRequest(method, urlComponents,
headers: headers, parameters: parameters, body: body, doAuth: doAuth)

return (try await bytes.asData(), response)
return (try await channel.asData(), response)
}

private func bytesRequest(
private func channelRequest(
_ method: HTTPMethod,
_ urlComponents: URLComponents,
headers: Dictionary<String, String> = Dictionary(),
parameters: Dictionary<String, String> = Dictionary(),
body: Data? = nil,
doAuth: Bool = true
) async throws -> (URLSession.AsyncBytes, HTTPURLResponse) {
) async throws -> (AsyncThrowingChannel<Data, Error>, HTTPURLResponse) {
var urlComponents = urlComponents

if urlComponents.queryItems == nil && !parameters.isEmpty {
Expand All @@ -294,14 +292,14 @@ class Registry {
currentAuthToken = nil
}

var (bytes, response) = try await authAwareRequest(request: request)
var (channel, response) = try await authAwareRequest(request: request)

if doAuth && response.statusCode == HTTPCode.Unauthorized.rawValue {
try await auth(response: response)
(bytes, response) = try await authAwareRequest(request: request)
(channel, response) = try await authAwareRequest(request: request)
}

return (bytes, response)
return (channel, response)
}

private func auth(response: HTTPURLResponse) async throws {
Expand Down Expand Up @@ -373,16 +371,16 @@ class Registry {
return nil
}

private func authAwareRequest(request: URLRequest) async throws -> (URLSession.AsyncBytes, HTTPURLResponse) {
private func authAwareRequest(request: URLRequest) async throws -> (AsyncThrowingChannel<Data, Error>, HTTPURLResponse) {
var request = request

if let token = currentAuthToken {
let (name, value) = token.header()
request.addValue(value, forHTTPHeaderField: name)
}

let (bytes, response) = try await URLSession.shared.bytes(for: request)
let (channel, response) = try await Fetcher().fetch(request)

return (bytes, response as! HTTPURLResponse)
return (channel, response as! HTTPURLResponse)
}
}
4 changes: 2 additions & 2 deletions Sources/tart/VM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class VM: NSObject, VZVirtualMachineDelegate, ObservableObject {

static func retrieveIPSW(remoteURL: URL) async throws -> URL {
// Check if we already have this IPSW in cache
let (bytes, response) = try await URLSession.shared.bytes(from: remoteURL)
let (channel, response) = try await Fetcher().fetch(URLRequest(url: remoteURL))

if let hash = (response as! HTTPURLResponse).value(forHTTPHeaderField: "x-amz-meta-digest-sha256") {
let ipswLocation = try IPSWCache().locationFor(fileName: "sha256:\(hash).ipsw")
Expand All @@ -90,7 +90,7 @@ class VM: NSObject, VZVirtualMachineDelegate, ObservableObject {
let fileHandle = try FileHandle(forWritingTo: temporaryLocation)
let digest = Digest()

for try await chunk in bytes.chunks(ofCount: chunkSizeBytes) {
for try await chunk in channel {
let chunkAsData = Data(chunk)
fileHandle.write(chunkAsData)
digest.update(chunkAsData)
Expand Down

0 comments on commit 39e1b84

Please sign in to comment.