Skip to content

Commit

Permalink
Add sparsity to the small sequences graph
Browse files: browse the repository at this point in the history
  • Loading branch information
philipturner committed Jul 27, 2023
1 parent aa58e2c commit 0951011
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 36 deletions.
Binary file modified CI/float16-small-causal-latest.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions MetalFlashAttention.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@
984279D62A619008001BBD55 /* AttentionTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AttentionTest.swift; sourceTree = "<group>"; };
984F721A2A6EEB0E00C15D4A /* float16-small-sequences-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-sequences-latest.png"; sourceTree = "<group>"; };
984F721B2A6EEC4B00C15D4A /* float16-large-sequences-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-large-sequences-latest.png"; sourceTree = "<group>"; };
984F721C2A6EEE0F00C15D4A /* float16-small-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-causal-latest.png"; sourceTree = "<group>"; };
984F721D2A6EF09000C15D4A /* float16-head-sizes-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-head-sizes-latest.png"; sourceTree = "<group>"; };
987E35DB2A45E4F400ACACE3 /* MetalFlashAttention */ = {isa = PBXFileReference; explicitFileType = "compiled.mach-o.executable"; includeInIndex = 0; path = MetalFlashAttention; sourceTree = BUILT_PRODUCTS_DIR; };
987E35DE2A45E4F400ACACE3 /* main.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = main.swift; sourceTree = "<group>"; };
Expand All @@ -104,6 +103,7 @@
98C795312A4DC1F200DB688D /* GEMMSquareBenchmark.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GEMMSquareBenchmark.swift; sourceTree = "<group>"; };
98DFBD0C2A72F0EC002E4B47 /* float16-large-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-large-causal-latest.png"; sourceTree = "<group>"; };
98DFBD0D2A72F242002E4B47 /* float32-large-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float32-large-causal-latest.png"; sourceTree = "<group>"; };
98DFBD0E2A72F92C002E4B47 /* float16-small-causal-latest.png */ = {isa = PBXFileReference; lastKnownFileType = image.png; path = "float16-small-causal-latest.png"; sourceTree = "<group>"; };
98F2F5DC2A60978C006216F4 /* GEMMTransposeTest.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GEMMTransposeTest.swift; sourceTree = "<group>"; };
98F7440E2A4A008C00B5E60A /* build.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = build.swift; sourceTree = "<group>"; };
98F7440F2A4A0CB200B5E60A /* API.md */ = {isa = PBXFileReference; lastKnownFileType = net.daringfireball.markdown; path = API.md; sourceTree = "<group>"; };
Expand Down Expand Up @@ -134,7 +134,7 @@
984F721A2A6EEB0E00C15D4A /* float16-small-sequences-latest.png */,
984F721B2A6EEC4B00C15D4A /* float16-large-sequences-latest.png */,
98DFBD0D2A72F242002E4B47 /* float32-large-causal-latest.png */,
984F721C2A6EEE0F00C15D4A /* float16-small-causal-latest.png */,
98DFBD0E2A72F92C002E4B47 /* float16-small-causal-latest.png */,
98DFBD0C2A72F0EC002E4B47 /* float16-large-causal-latest.png */,
984F721D2A6EF09000C15D4A /* float16-head-sizes-latest.png */,
98FDDF002A5895CE0096BC27 /* float16-nt-batched-latest.png */,
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,16 @@ Dense: Stable Diffusion 2 outermost attention layer @ 512x512 (sequence length =

![FlashAttention (F16, H=5, D=64)](./CI/float16-large-sequences-latest.png)

### Float32 Sequence Scaling (Causal Mask)

![FlashAttention (F32, H=10, D=64)](./CI/float32-large-causal-latest.png)

### Float16 Sequence Scaling (Causal Mask)

![FlashAttention (F16, H=10, D=64)](./CI/float16-small-causal-latest.png)

![FlashAttention (F16, H=10, D=64)](./CI/float16-large-causal-latest.png)

### Float32 Sequence Scaling (Causal Mask)

![FlashAttention (F32, H=10, D=64)](./CI/float32-large-causal-latest.png)

### Float16 Head Scaling

Dense: Stable Diffusion 1 outermost attention layer @ 512x512 (head size = 40)
Expand Down
59 changes: 30 additions & 29 deletions Tests/Test Cases/AttentionPerfTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,24 @@ class AttentionPerfTests: MFATestCase {
// For heads scaling:
// sequence length 4096

let duration = Duration(granularity: -1, length: 2)
let duration = Duration(granularity: 2, length: 2)
let (domain, ranges) = rangeSequenceScaling(
duration: duration, type: .causal)

var backends = SequenceType.causal.backends
// let backends: [AttentionBackend] = [.mfa]

backends = backends.compactMap {
if $0.isMPS { return nil }
return $0
}
// backends = backends.compactMap {
// if $0.isMPS { return nil }
// return $0
// }

// let duration = Duration(granularity: 1, length: 1)
// let (domain, ranges) = rangeHeadScaling(duration: duration)
// let backends = [AttentionBackend.mps, AttentionBackend.mfa]
testAttention(
domain: domain, ranges: ranges, backends: backends, config: .triangular)
domain: domain, ranges: ranges, backends: backends,
config: .none)
}

enum GraphConfig {
Expand Down Expand Up @@ -277,33 +278,33 @@ class AttentionPerfTests: MFATestCase {
var domain: ClosedRange<Int>
var parameters: [SIMD8<Int>]
if type == .causal {
domain = 0...16384
// domain = 0...16384
// domain = 512...1024
// domain = 0...1024
domain = 0...1024
parameters = [
SIMD8( 1, 8, 256, 1, 0, 0, 0, 0),
SIMD8( 8, 192, 256, 8, 0, 0, 0, 0),
SIMD8( 192, 256, 128, 8, 0, 0, 0, 0),
SIMD8( 256, 384, 64, 8, 0, 0, 0, 0),
SIMD8( 384, 512, 32, 8, 0, 0, 0, 0),
SIMD8( 512, 768, 16, 16, 0, 0, 0, 0),
SIMD8( 768, 1024, 8, 16, 0, 0, 0, 0),
SIMD8(1024, 1536, 4, 32, 0, 0, 0, 0),
SIMD8(1536, 2048, 2, 32, 8, 0, 0, 0),
SIMD8(2048, 3072, 2, 64, 8, 0, 0, 0),
SIMD8(3072, 4096, 2, 128, 8, 0, 0, 0),
SIMD8(4096, 6144, 2, 256, 8, 0, 0, 0),
SIMD8( 6 * 1024, 8 * 1024, 2, 512, 7, 0, 0, 0),
SIMD8( 8 * 1024, 12 * 1024, 2, 1024, 6, 0, 0, 0),
SIMD8(12 * 1024, 16 * 1024 + 1, 2, 2048, 5, 0, 0, 0),
// SIMD8( 1, 8, 256, 1, 0, 0, 0, 0),
// SIMD8( 8, 192, 256, 8, 0, 0, 0, 0),
// SIMD8( 192, 256, 128, 8, 0, 0, 0, 0),
// SIMD8( 256, 384, 64, 8, 0, 0, 0, 0),
// SIMD8( 384, 512, 32, 8, 0, 0, 0, 0),
// SIMD8( 512, 768, 16, 16, 0, 0, 0, 0),
// SIMD8( 768, 1024, 8, 16, 0, 0, 0, 0),
// SIMD8(1024, 1536, 4, 32, 0, 0, 0, 0),
// SIMD8(1536, 2048, 2, 32, 8, 0, 0, 0),
// SIMD8(2048, 3072, 2, 64, 8, 0, 0, 0),
// SIMD8(3072, 4096, 2, 128, 8, 0, 0, 0),
// SIMD8(4096, 6144, 2, 256, 8, 0, 0, 0),
// SIMD8( 6 * 1024, 8 * 1024, 2, 512, 7, 0, 0, 0),
// SIMD8( 8 * 1024, 12 * 1024, 2, 1024, 6, 0, 0, 0),
// SIMD8(12 * 1024, 16 * 1024 + 1, 2, 2048, 5, 0, 0, 0),

// SIMD4(granularity, 192, 256, granularity, 0, 0, 0, 0),
// SIMD4( 192, 256, 128, granularity, 0, 0, 0, 0),
// SIMD4( 256, 384, 64, granularity, 0, 0, 0, 0),
// SIMD4( 384, 512, 32, granularity, 0, 0, 0, 0),
SIMD8(granularity, 192, 256, granularity, 0, 0, 0, 0),
SIMD8( 192, 256, 128, granularity, 0, 0, 0, 0),
SIMD8( 256, 384, 64, granularity, 0, 0, 0, 0),
SIMD8( 384, 512, 32, granularity, 0, 0, 0, 0),

// SIMD4( 512, 768, 16, granularity, 0, 0, 0, 0),
// SIMD4( 768, 1025, 8, granularity, 0, 0, 0, 0),
SIMD8( 512, 768, 16, granularity, 0, 0, 0, 0),
SIMD8( 768, 1025, 8, granularity, 0, 0, 0, 0),
]
} else if type == .small {
domain = 0...2048
Expand Down
2 changes: 1 addition & 1 deletion Tests/Test Cases/MFATestCase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Foundation
class MFATestCase {
// Global setting for the precision used in tests.
#if arch(arm64)
typealias Real = Float32
typealias Real = Float16
#else
typealias Real = Float
#endif
Expand Down

0 comments on commit 0951011

Please sign in to comment.