forked from nlpodyssey/safetensors
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dtype.go
88 lines (81 loc) · 1.91 KB
/
dtype.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// Copyright 2024 Marc-Antoine Ruel. All rights reserved.
// Copyright 2023 The NLP Odyssey Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package safetensors
import (
"encoding/json"
"fmt"
)
// DType identifies a data type by its name.
//
// The names declared as constants in this package (BOOL, U8, ...) are the
// keys of DTypeToWordSize. UnmarshalJSON rejects any string whose
// DTypeToWordSize entry is missing or zero.
//
// It matches the DType type at
// https://github.com/huggingface/safetensors/blob/main/safetensors/src/tensor.rs.
type DType string
// Names of the data types known to this package. Each one has a matching
// entry in DTypeToWordSize giving its size in bytes.
const (
// BOOL is the boolean type.
BOOL DType = "BOOL"
// U8 is an unsigned byte.
U8 DType = "U8"
// I8 is a signed byte.
I8 DType = "I8"
// F8_E5M2 is an FP8 floating point format; see
// <https://arxiv.org/pdf/2209.05433.pdf>.
F8_E5M2 DType = "F8_E5M2"
// F8_E4M3 is an FP8 floating point format; see
// <https://arxiv.org/pdf/2209.05433.pdf>.
F8_E4M3 DType = "F8_E4M3"
// I16 is a signed integer (16-bit).
I16 DType = "I16"
// U16 is an unsigned integer (16-bit).
U16 DType = "U16"
// F16 is a half-precision floating point.
F16 DType = "F16"
// BF16 is a brain floating point.
BF16 DType = "BF16"
// I32 is a signed integer (32-bit).
I32 DType = "I32"
// U32 is an unsigned integer (32-bit).
U32 DType = "U32"
// F32 is a floating point (32-bit).
F32 DType = "F32"
// F64 is a floating point (64-bit).
F64 DType = "F64"
// I64 is a signed integer (64-bit).
I64 DType = "I64"
// U64 is an unsigned integer (64-bit).
U64 DType = "U64"
)
// DTypeToWordSize maps every known DType to the number of bytes that a single
// element of that type occupies.
//
// Entries are grouped by type family rather than by size.
var DTypeToWordSize = map[DType]uint64{
	// Boolean.
	BOOL: 1,

	// Unsigned integers.
	U8:  1,
	U16: 2,
	U32: 4,
	U64: 8,

	// Signed integers.
	I8:  1,
	I16: 2,
	I32: 4,
	I64: 8,

	// Floating point.
	F8_E5M2: 1,
	F8_E4M3: 1,
	F16:     2,
	BF16:    2,
	F32:     4,
	F64:     8,
}
// WordSize returns the size in bytes of one element of this data type.
//
// The size is looked up in DTypeToWordSize; a DType that has no entry there
// yields 0.
func (dt DType) WordSize() uint64 {
return DTypeToWordSize[dt]
}
// UnmarshalJSON implements json.Unmarshaler.
//
// data is decoded as a JSON string and accepted only when that string names a
// DType with a non-zero size in DTypeToWordSize. On any error the receiver is
// left unmodified.
func (dt *DType) UnmarshalJSON(data []byte) error {
	var name string
	err := json.Unmarshal(data, &name)
	if err != nil {
		return err
	}

	// A missing map entry reads as 0, so a zero size marks an unknown name.
	candidate := DType(name)
	if size := DTypeToWordSize[candidate]; size == 0 {
		return fmt.Errorf("%q is not a valid DType", name)
	}

	*dt = candidate
	return nil
}