Skip to content

Commit

Permalink
Merge pull request #727 from DenisFrezzato/improve-refinement-type
Browse files Browse the repository at this point in the history
Improve refine with type guards
  • Loading branch information
colinhacks authored Oct 26, 2021
2 parents a774526 + a046881 commit fdd7084
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 8 deletions.
33 changes: 33 additions & 0 deletions deno/lib/__tests__/refine.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { expect } from "https://deno.land/x/[email protected]/mod.ts";
const test = Deno.test;

import { util } from "../helpers/util.ts";
import * as z from "../index.ts";
import { ZodIssueCode } from "../ZodError.ts";

Expand Down Expand Up @@ -49,6 +50,38 @@ test("refinement 2", () => {
).toThrow();
});

test("refinement type guard", () => {
const validationSchema = z.object({
a: z.string().refine((s): s is "a" => s === "a"),
});
type Schema = z.infer<typeof validationSchema>;

const f1: util.AssertEqual<"a", Schema["a"]> = true;
f1;
const f2: util.AssertEqual<"string", Schema["a"]> = false;
f2;
});

test("refinement Promise", async () => {
const validationSchema = z
.object({
email: z.string().email(),
password: z.string(),
confirmPassword: z.string(),
})
.refine(
(data) =>
Promise.resolve().then(() => data.password === data.confirmPassword),
"Both password and confirmation must match"
);

await validationSchema.parseAsync({
email: "[email protected]",
password: "password",
confirmPassword: "password",
});
});

test("custom path", async () => {
const result = await z
.object({
Expand Down
24 changes: 20 additions & 4 deletions deno/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,16 @@ export abstract class ZodType<
/** The .check method has been removed in Zod 3. For details see https://github.com/colinhacks/zod/tree/v3. */
check!: never;

refine<Func extends (arg: Output) => any>(
check: Func,
refine<RefinedOutput extends Output>(
check: (arg: Output) => arg is RefinedOutput,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, RefinedOutput, RefinedOutput>;
refine(
check: (arg: Output) => unknown | Promise<unknown>,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, Output, Input>;
refine(
check: (arg: Output) => unknown,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, Output, Input> {
const getIssueProperties: any = (val: Output) => {
Expand Down Expand Up @@ -264,8 +272,16 @@ export abstract class ZodType<
});
}

refinement<RefinedOutput extends Output>(
check: (arg: Output) => arg is RefinedOutput,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, RefinedOutput, RefinedOutput>;
refinement(
check: (arg: Output) => any,
check: (arg: Output) => boolean,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, Output, Input>;
refinement(
check: (arg: Output) => unknown,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, Output, Input> {
return this._refinement((val, ctx) => {
Expand All @@ -289,7 +305,7 @@ export abstract class ZodType<
schema: this,
typeName: ZodFirstPartyTypeKind.ZodEffects,
effect: { type: "refinement", refinement },
}) as any;
});
}
superRefine = this._refinement;

Expand Down
33 changes: 33 additions & 0 deletions src/__tests__/refine.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// @ts-ignore TS6133
import { expect, test } from "@jest/globals";

import { util } from "../helpers/util";
import * as z from "../index";
import { ZodIssueCode } from "../ZodError";

Expand Down Expand Up @@ -48,6 +49,38 @@ test("refinement 2", () => {
).toThrow();
});

test("refinement type guard", () => {
const validationSchema = z.object({
a: z.string().refine((s): s is "a" => s === "a"),
});
type Schema = z.infer<typeof validationSchema>;

const f1: util.AssertEqual<"a", Schema["a"]> = true;
f1;
const f2: util.AssertEqual<"string", Schema["a"]> = false;
f2;
});

test("refinement Promise", async () => {
const validationSchema = z
.object({
email: z.string().email(),
password: z.string(),
confirmPassword: z.string(),
})
.refine(
(data) =>
Promise.resolve().then(() => data.password === data.confirmPassword),
"Both password and confirmation must match"
);

await validationSchema.parseAsync({
email: "[email protected]",
password: "password",
confirmPassword: "password",
});
});

test("custom path", async () => {
const result = await z
.object({
Expand Down
24 changes: 20 additions & 4 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,16 @@ export abstract class ZodType<
/** The .check method has been removed in Zod 3. For details see https://github.com/colinhacks/zod/tree/v3. */
check!: never;

refine<Func extends (arg: Output) => any>(
check: Func,
refine<RefinedOutput extends Output>(
check: (arg: Output) => arg is RefinedOutput,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, RefinedOutput, RefinedOutput>;
refine(
check: (arg: Output) => unknown | Promise<unknown>,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, Output, Input>;
refine(
check: (arg: Output) => unknown,
message?: string | CustomErrorParams | ((arg: Output) => CustomErrorParams)
): ZodEffects<this, Output, Input> {
const getIssueProperties: any = (val: Output) => {
Expand Down Expand Up @@ -264,8 +272,16 @@ export abstract class ZodType<
});
}

refinement<RefinedOutput extends Output>(
check: (arg: Output) => arg is RefinedOutput,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, RefinedOutput, RefinedOutput>;
refinement(
check: (arg: Output) => any,
check: (arg: Output) => boolean,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, Output, Input>;
refinement(
check: (arg: Output) => unknown,
refinementData: IssueData | ((arg: Output, ctx: RefinementCtx) => IssueData)
): ZodEffects<this, Output, Input> {
return this._refinement((val, ctx) => {
Expand All @@ -289,7 +305,7 @@ export abstract class ZodType<
schema: this,
typeName: ZodFirstPartyTypeKind.ZodEffects,
effect: { type: "refinement", refinement },
}) as any;
});
}
superRefine = this._refinement;

Expand Down

0 comments on commit fdd7084

Please sign in to comment.