Skip to content

Commit

Permalink
add bit_count and get_bit functions
Browse files Browse the repository at this point in the history
  • Loading branch information
ParkMyCar committed Jan 17, 2025
1 parent 420f0ff commit 91970a2
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/expr/src/scalar.proto
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ message ProtoUnaryFunc {
ProtoToCharTimestamp to_char_timestamp = 331;
ProtoToCharTimestamp to_char_timestamp_tz = 332;
google.protobuf.Empty cast_date_to_mz_timestamp = 333;
google.protobuf.Empty bit_count_bytes = 334;
}
}

Expand Down Expand Up @@ -668,6 +669,7 @@ message ProtoBinaryFunc {
bool list_contains_list = 193;
bool array_contains_array = 194;
google.protobuf.Empty starts_with = 195;
google.protobuf.Empty get_bit = 196;
}
}

Expand Down
35 changes: 34 additions & 1 deletion src/expr/src/scalar/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,26 @@ fn power_numeric<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError>
}
}

/// Returns the bit at position `b` of the byte string `a` as an `Int32`
/// datum holding 0 or 1.
///
/// Bit `n` is read from byte `n / 8` by shifting that byte right by `n % 8`,
/// so position 0 is the least-significant bit of the first byte.
///
/// # Errors
///
/// Returns `EvalError::IndexOutOfRange` when `b` is negative or addresses a
/// byte past the end of `a`.
fn get_bit<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
    let bytes = a.unwrap_bytes();
    let provided = b.unwrap_int32();

    // Build the error lazily: it is only needed on the failure paths.
    //
    // The last valid bit position is clamped to `i32::MAX` rather than
    // unwrapped. A byte string of 256 MiB or more holds more bits than an
    // `i32` can represent, and unwrapping the conversion would panic on such
    // input even when the requested index is perfectly valid. For an empty
    // byte string this still reports -1, as before.
    let out_of_range = || EvalError::IndexOutOfRange {
        provided,
        valid_end: i32::try_from(bytes.len().saturating_mul(8))
            .map_or(i32::MAX, |bit_len| bit_len - 1),
    };

    // A negative index fails the conversion to `usize`.
    let position = usize::try_from(provided).map_err(|_| out_of_range())?;

    let byte_index = position / 8;
    let bit_index = position % 8;

    // `get` returns `None` when `byte_index` is past the end of the slice.
    let bit = bytes
        .get(byte_index)
        .map(|byte| (*byte >> bit_index) & 1)
        .ok_or_else(out_of_range)?;
    Ok(Datum::from(i32::from(bit)))
}

fn get_byte<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
let bytes = a.unwrap_bytes();
let index = b.unwrap_int32();
Expand Down Expand Up @@ -2344,6 +2364,7 @@ pub enum BinaryFunc {
LogNumeric,
Power,
PowerNumeric,
GetBit,
GetByte,
ConstantTimeEqBytes,
ConstantTimeEqString,
Expand Down Expand Up @@ -2607,6 +2628,7 @@ impl BinaryFunc {
BinaryFunc::Power => power(a, b),
BinaryFunc::PowerNumeric => power_numeric(a, b),
BinaryFunc::RepeatString => repeat_string(a, b, temp_storage),
BinaryFunc::GetBit => get_bit(a, b),
BinaryFunc::GetByte => get_byte(a, b),
BinaryFunc::ConstantTimeEqBytes => constant_time_eq_bytes(a, b),
BinaryFunc::ConstantTimeEqString => constant_time_eq_string(a, b),
Expand Down Expand Up @@ -2804,6 +2826,7 @@ impl BinaryFunc {
ScalarType::Numeric { max_scale: None }.nullable(in_nullable)
}

GetBit => ScalarType::Int32.nullable(in_nullable),
GetByte => ScalarType::Int32.nullable(in_nullable),

ConstantTimeEqBytes | ConstantTimeEqString => {
Expand Down Expand Up @@ -3023,6 +3046,7 @@ impl BinaryFunc {
| LogNumeric
| Power
| PowerNumeric
| GetBit
| GetByte
| RangeContainsElem { .. }
| RangeContainsRange { .. }
Expand Down Expand Up @@ -3241,6 +3265,7 @@ impl BinaryFunc {
| ListRemove
| LikeEscape
| UuidGenerateV5
| GetBit
| GetByte
| MzAclItemContainsPrivilege
| ConstantTimeEqBytes
Expand Down Expand Up @@ -3508,7 +3533,8 @@ impl BinaryFunc {
| BinaryFunc::Decode => (false, false),
// TODO: it may be safe to treat these as monotone.
BinaryFunc::LogNumeric | BinaryFunc::Power | BinaryFunc::PowerNumeric => (false, false),
BinaryFunc::GetByte
BinaryFunc::GetBit
| BinaryFunc::GetByte
| BinaryFunc::RangeContainsElem { .. }
| BinaryFunc::RangeContainsRange { .. }
| BinaryFunc::RangeOverlaps
Expand Down Expand Up @@ -3716,6 +3742,7 @@ impl fmt::Display for BinaryFunc {
BinaryFunc::Power => f.write_str("power"),
BinaryFunc::PowerNumeric => f.write_str("power_numeric"),
BinaryFunc::RepeatString => f.write_str("repeat"),
BinaryFunc::GetBit => f.write_str("get_bit"),
BinaryFunc::GetByte => f.write_str("get_byte"),
BinaryFunc::ConstantTimeEqBytes => f.write_str("constant_time_compare_bytes"),
BinaryFunc::ConstantTimeEqString => f.write_str("constant_time_compare_strings"),
Expand Down Expand Up @@ -4140,6 +4167,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
BinaryFunc::LogNumeric => LogNumeric(()),
BinaryFunc::Power => Power(()),
BinaryFunc::PowerNumeric => PowerNumeric(()),
BinaryFunc::GetBit => GetBit(()),
BinaryFunc::GetByte => GetByte(()),
BinaryFunc::RangeContainsElem { elem_type, rev } => {
RangeContainsElem(crate::scalar::proto_binary_func::ProtoRangeContainsInner {
Expand Down Expand Up @@ -4360,6 +4388,7 @@ impl RustType<ProtoBinaryFunc> for BinaryFunc {
LogNumeric(()) => Ok(BinaryFunc::LogNumeric),
Power(()) => Ok(BinaryFunc::Power),
PowerNumeric(()) => Ok(BinaryFunc::PowerNumeric),
GetBit(()) => Ok(BinaryFunc::GetBit),
GetByte(()) => Ok(BinaryFunc::GetByte),
RangeContainsElem(inner) => Ok(BinaryFunc::RangeContainsElem {
elem_type: inner
Expand Down Expand Up @@ -4799,6 +4828,7 @@ derive_unary!(
FloorFloat64,
FloorNumeric,
Ascii,
BitCountBytes,
BitLengthBytes,
BitLengthString,
ByteLengthBytes,
Expand Down Expand Up @@ -5209,6 +5239,7 @@ impl Arbitrary for UnaryFunc {
FloorFloat64::arbitrary().prop_map_into().boxed(),
FloorNumeric::arbitrary().prop_map_into().boxed(),
Ascii::arbitrary().prop_map_into().boxed(),
BitCountBytes::arbitrary().prop_map_into().boxed(),
BitLengthBytes::arbitrary().prop_map_into().boxed(),
BitLengthString::arbitrary().prop_map_into().boxed(),
ByteLengthBytes::arbitrary().prop_map_into().boxed(),
Expand Down Expand Up @@ -5597,6 +5628,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
UnaryFunc::FloorFloat64(_) => FloorFloat64(()),
UnaryFunc::FloorNumeric(_) => FloorNumeric(()),
UnaryFunc::Ascii(_) => Ascii(()),
UnaryFunc::BitCountBytes(_) => BitCountBytes(()),
UnaryFunc::BitLengthBytes(_) => BitLengthBytes(()),
UnaryFunc::BitLengthString(_) => BitLengthString(()),
UnaryFunc::ByteLengthBytes(_) => ByteLengthBytes(()),
Expand Down Expand Up @@ -6071,6 +6103,7 @@ impl RustType<ProtoUnaryFunc> for UnaryFunc {
FloorFloat64(_) => Ok(impls::FloorFloat64.into()),
FloorNumeric(_) => Ok(impls::FloorNumeric.into()),
Ascii(_) => Ok(impls::Ascii.into()),
BitCountBytes(_) => Ok(impls::BitCountBytes.into()),
BitLengthBytes(_) => Ok(impls::BitLengthBytes.into()),
BitLengthString(_) => Ok(impls::BitLengthString.into()),
ByteLengthBytes(_) => Ok(impls::ByteLengthBytes.into()),
Expand Down
9 changes: 9 additions & 0 deletions src/expr/src/scalar/func/impls/byte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use mz_ore::cast::CastFrom;
use mz_repr::strconv;

use crate::EvalError;
Expand Down Expand Up @@ -64,6 +65,14 @@ sqlfunc!(
}
);

// SQL `bit_count` for byte strings: the number of set bits in the input,
// returned as an `i64`.
sqlfunc!(
    #[sqlname = "bit_count"]
    fn bit_count_bytes<'a>(a: &'a [u8]) -> Result<i64, EvalError> {
        // Sum the per-byte population counts in a `u64` accumulator.
        let count: u64 = a.iter().map(|b| u64::cast_from(b.count_ones())).sum();
        // Use `map_err` so the error (and the string it formats) is only
        // built when the conversion fails. The previous `.or(Err(..))` form
        // evaluated its argument eagerly and allocated on every call.
        i64::try_from(count).map_err(|_| EvalError::Int64OutOfRange(count.to_string().into()))
    }
);

sqlfunc!(
#[sqlname = "bit_length"]
fn bit_length_bytes<'a>(a: &'a [u8]) -> Result<i32, EvalError> {
Expand Down
6 changes: 6 additions & 0 deletions src/sql/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1866,6 +1866,9 @@ pub static PG_CATALOG_BUILTINS: LazyLock<BTreeMap<&'static str, Func>> = LazyLoc
params!(Float64) => Operation::nullary(|_ecx| catalog_name_only!("avg")) => Float64, 2105;
params!(Interval) => Operation::nullary(|_ecx| catalog_name_only!("avg")) => Interval, 2106;
},
"bit_count" => Scalar {
params!(Bytes) => UnaryFunc::BitCountBytes(func::BitCountBytes) => Int64, 6163;
},
"bit_length" => Scalar {
params!(Bytes) => UnaryFunc::BitLengthBytes(func::BitLengthBytes) => Int32, 1810;
params!(String) => UnaryFunc::BitLengthString(func::BitLengthString) => Int32, 1811;
Expand Down Expand Up @@ -2085,6 +2088,9 @@ pub static PG_CATALOG_BUILTINS: LazyLock<BTreeMap<&'static str, Func>> = LazyLoc
END"
) => String, 1081;
},
"get_bit" => Scalar {
params!(Bytes, Int32) => BinaryFunc::GetBit => Int32, 723;
},
"get_byte" => Scalar {
params!(Bytes, Int32) => BinaryFunc::GetByte => Int32, 721;
},
Expand Down
113 changes: 113 additions & 0 deletions test/sqllogictest/bytea.slt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,119 @@ SELECT bit_length('DEADBEEF'::text);
----
64

query I
SELECT bit_count('\x1234567890'::bytea);
----
15

query I
SELECT bit_count('\x00'::bytea);
----
0

query I
SELECT bit_count('\x0F'::bytea);
----
4

query I
SELECT bit_count('\xFF'::bytea);
----
8

query I
SELECT bit_count('\xF0FF'::bytea);
----
12

query I
SELECT get_byte('\x1234567890'::bytea, 4);
----
144

query I
SELECT get_bit('\x1234567890'::bytea, 30);
----
1

query II
SELECT n, get_bit('\x1234567890'::bytea, n) FROM generate_series(0, 39) as n ORDER BY n DESC;
----
39 1
38 0
37 0
36 1
35 0
34 0
33 0
32 0
31 0
30 1
29 1
28 1
27 1
26 0
25 0
24 0
23 0
22 1
21 0
20 1
19 0
18 1
17 1
16 0
15 0
14 0
13 1
12 1
11 0
10 1
9 0
8 0
7 0
6 0
5 0
4 1
3 0
2 0
1 1
0 0


query I
SELECT get_bit('\xF00a'::bytea, 13);
----
0

query I
SELECT get_bit('\xF00a'::bytea, 5);
----
1

query II
SELECT n, get_bit('\xF00a'::bytea, n) FROM generate_series(0, 15) as n ORDER BY n DESC;
----
15 0
14 0
13 0
12 0
11 1
10 0
9 1
8 0
7 1
6 1
5 1
4 1
3 0
2 0
1 0
0 0

statement error index 16 out of valid range, 0..15
SELECT get_bit('\xF00a'::bytea, 16);

statement error
SELECT length('deadbeef'::text, 'utf-8')

Expand Down

0 comments on commit 91970a2

Please sign in to comment.