Skip to content

Commit

Permalink
fix(mssql): ensure ibis.random() generates a new value per call (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored Sep 19, 2024
1 parent 82e9ba0 commit 1667f43
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 4 deletions.
6 changes: 6 additions & 0 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def to_sqlglot(
def visit_RandomUUID(self, op, **_):
return self.f.newid()

def visit_RandomScalar(self, op, **_):
# By default RAND() will generate the same value for all calls within a
# query. The standard way to work around this is to pass in a unique
# value per call, which `CHECKSUM(NEWID())` provides.
return self.f.rand(self.f.checksum(self.f.newid()))

def visit_StringLength(self, op, *, arg):
"""The MSSQL LEN function doesn't count trailing spaces.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
TOP 10
NTILE(2) OVER (ORDER BY RAND() ASC) - 1 AS [new_col]
NTILE(2) OVER (ORDER BY RAND(CHECKSUM(NEWID())) ASC) - 1 AS [new_col]
FROM [test] AS [t0]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ SELECT
FROM (
SELECT
[t0].[x],
RAND() AS [y],
RAND() AS [z]
RAND(CHECKSUM(NEWID())) AS [y],
RAND(CHECKSUM(NEWID())) AS [z]
FROM [t] AS [t0]
) AS [t1]
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def test_order_by(backend, alltypes, df, key, df_kwargs):
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["polars", "mssql", "druid"])
@pytest.mark.notimpl(["polars", "druid"])
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,18 @@ def test_random(con):
assert 0 <= result <= 1


@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
@pytest.mark.notimpl(
["risingwave"],
raises=PsycoPg2InternalError,
reason="function random() does not exist",
)
def test_random_different_per_row(alltypes):
result = alltypes.select("int_col", rand_col=ibis.random()).execute()
assert result.rand_col.nunique() > 1


@pytest.mark.parametrize(
("ibis_func", "pandas_func"),
[
Expand Down

0 comments on commit 1667f43

Please sign in to comment.