diff --git a/packages/orm/src/client/crud/dialects/base-dialect.ts b/packages/orm/src/client/crud/dialects/base-dialect.ts index 1f5102121..adb88bd6c 100644 --- a/packages/orm/src/client/crud/dialects/base-dialect.ts +++ b/packages/orm/src/client/crud/dialects/base-dialect.ts @@ -84,6 +84,22 @@ export abstract class BaseCrudDialect { // #endregion + // #region type mapping + + /** + * Maps a ZModel type to the corresponding SQL type for this dialect. + */ + protected abstract getSqlType(zmodelType: string): string | undefined; + + /** + * Checks if a field has a native database type attribute (e.g., `@db.Uuid`). + */ + protected hasNativeTypeAttribute(fieldDef: FieldDef): boolean { + return !!fieldDef.attributes?.some((a) => a.name.startsWith('@db.')); + } + + // #endregion + // #region value transformation /** @@ -1134,7 +1150,7 @@ export abstract class BaseCrudDialect { const descendants = getDelegateDescendantModels(this.schema, model); for (const subModel of descendants) { result = this.buildDelegateJoin(model, modelAlias, subModel.name, result); - result = result.select((eb) => { + result = result.select(() => { const jsonObject: Record> = {}; for (const field of Object.keys(subModel.fields)) { if ( @@ -1143,7 +1159,7 @@ export abstract class BaseCrudDialect { ) { continue; } - jsonObject[field] = eb.ref(`${subModel.name}.${field}`); + jsonObject[field] = this.fieldRef(subModel.name, field, subModel.name); } return this.buildJsonObject(jsonObject).as(`${DELEGATE_JOINED_FIELD_PREFIX}${subModel.name}`); }); @@ -1344,7 +1360,19 @@ export abstract class BaseCrudDialect { if (!fieldDef.computed) { // regular field - return this.eb.ref(modelAlias ? `${modelAlias}.${field}` : field); + const ref = modelAlias ? `${modelAlias}.${field}` : field; + + // if the field has a native database type annotation (e.g., @db.Uuid), cast it + // back to the base SQL type to avoid type mismatch in comparisons + if (this.hasNativeTypeAttribute(fieldDef)) { + const sqlType = this.getSqlType(fieldDef.type); + if (sqlType) { + const castType = fieldDef.array ? sql`${sql.raw(sqlType)}[]` : sql.raw(sqlType); + return sql`CAST(${sql.ref(ref)} AS ${castType})`; + } + } + + return this.eb.ref(ref); } else { // computed field if (!inlineComputedField) { diff --git a/packages/orm/src/client/crud/dialects/mysql.ts b/packages/orm/src/client/crud/dialects/mysql.ts index 6e444227b..1954ff2d7 100644 --- a/packages/orm/src/client/crud/dialects/mysql.ts +++ b/packages/orm/src/client/crud/dialects/mysql.ts @@ -16,7 +16,7 @@ import type { BuiltinType, FieldDef, SchemaDef } from '../../../schema'; import type { SortOrder } from '../../crud-types'; import { createInvalidInputError, createNotSupportedError } from '../../errors'; import type { ClientOptions } from '../../options'; -import { isTypeDef } from '../../query-utils'; +import { isEnum, isTypeDef } from '../../query-utils'; import { LateralJoinDialectBase } from './lateral-join-dialect-base'; export class MySqlCrudDialect extends LateralJoinDialectBase { @@ -318,6 +318,23 @@ export class MySqlCrudDialect extends LateralJoinDiale ); } + protected override getSqlType(zmodelType: string) { + if (isEnum(this.schema, zmodelType)) { + return 'varchar(191)'; + } + return match(zmodelType) + .with('String', () => 'char') + .with('Boolean', () => 'unsigned') + .with('Int', () => 'signed') + .with('BigInt', () => 'signed') + .with('Float', () => 'double') + .with('Decimal', () => 'decimal(65,30)') + .with('DateTime', () => 'datetime(3)') + .with('Bytes', () => 'binary') + .with('Json', () => 'json') + .otherwise(() => undefined); + } + override getStringCasingBehavior() { // MySQL LIKE is case-insensitive by default (depends on collation), no ILIKE support return { supportsILike: false, likeCaseSensitive: false }; diff --git a/packages/orm/src/client/crud/dialects/postgresql.ts b/packages/orm/src/client/crud/dialects/postgresql.ts index 5f962dbb5..2ace46364 100644 --- a/packages/orm/src/client/crud/dialects/postgresql.ts +++ b/packages/orm/src/client/crud/dialects/postgresql.ts @@ -281,7 +281,11 @@ export class PostgresCrudDialect extends LateralJoinDi override buildArrayValue(values: Expression[], elemType: string): AliasableExpression { const arr = sql`ARRAY[${sql.join(values, sql.raw(','))}]`; const mappedType = this.getSqlType(elemType); - return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`); + if (mappedType) { + return this.eb.cast(arr, sql`${sql.raw(mappedType)}[]`); + } else { + return arr; + } } override buildArrayContains( @@ -293,7 +297,7 @@ export class PostgresCrudDialect extends LateralJoinDi const arrayExpr = sql`ARRAY[${value}]`; if (elemType) { const mappedType = this.getSqlType(elemType); - const typedArray = this.eb.cast(arrayExpr, sql`${sql.raw(mappedType)}[]`); + const typedArray = mappedType ? this.eb.cast(arrayExpr, sql`${sql.raw(mappedType)}[]`) : arrayExpr; return this.eb(field, '@>', typedArray); } else { return this.eb(field, '@>', arrayExpr); @@ -357,25 +361,22 @@ export class PostgresCrudDialect extends LateralJoinDi ); } - private getSqlType(zmodelType: string) { + protected override getSqlType(zmodelType: string) { if (isEnum(this.schema, zmodelType)) { // reduce enum to text for type compatibility return 'text'; } else { - return ( - match(zmodelType) - .with('String', () => 'text') - .with('Boolean', () => 'boolean') - .with('Int', () => 'integer') - .with('BigInt', () => 'bigint') - .with('Float', () => 'double precision') - .with('Decimal', () => 'decimal') - .with('DateTime', () => 'timestamp') - .with('Bytes', () => 'bytea') - .with('Json', () => 'jsonb') - // fallback to text - .otherwise(() => 'text') - ); + return match(zmodelType) + .with('String', () => 'text') + .with('Boolean', () => 'boolean') + .with('Int', () => 'integer') + .with('BigInt', () => 'bigint') + .with('Float', () => 'double precision') + .with('Decimal', () => 'decimal(65,30)') + .with('DateTime', () => 'timestamp(3)') + .with('Bytes', () => 'bytea') + .with('Json', () => 'jsonb') + .otherwise(() => undefined); } } @@ -414,8 +415,12 @@ export class PostgresCrudDialect extends LateralJoinDi .select( fields.map((f, i) => { const mappedType = this.getSqlType(f.type); - const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType); - return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name); + if (mappedType) { + const castType = f.array ? sql`${sql.raw(mappedType)}[]` : sql.raw(mappedType); + return this.eb.cast(sql.ref(`$values.column${i + 1}`), castType).as(f.name); + } else { + return sql.ref(`$values.column${i + 1}`).as(f.name); + } }), ); } diff --git a/packages/orm/src/client/crud/dialects/sqlite.ts b/packages/orm/src/client/crud/dialects/sqlite.ts index 95e2910f8..73712dd57 100644 --- a/packages/orm/src/client/crud/dialects/sqlite.ts +++ b/packages/orm/src/client/crud/dialects/sqlite.ts @@ -22,6 +22,7 @@ import { getDelegateDescendantModels, getManyToManyRelation, getRelationForeignKeyFieldPairs, + isEnum, requireField, requireIdFields, requireModel, @@ -488,6 +489,23 @@ export class SqliteCrudDialect extends BaseCrudDialect return this.eb.fn('trim', [expression, sql.lit('"')]) as unknown as T; } + protected override getSqlType(zmodelType: string) { + if (isEnum(this.schema, zmodelType)) { + return 'text'; + } + return match(zmodelType) + .with('String', () => 'text') + .with('Boolean', () => 'integer') + .with('Int', () => 'integer') + .with('BigInt', () => 'integer') + .with('Float', () => 'real') + .with('Decimal', () => 'decimal') + .with('DateTime', () => 'numeric') + .with('Bytes', () => 'blob') + .with('Json', () => 'jsonb') + .otherwise(() => undefined); + } + override getStringCasingBehavior() { // SQLite `LIKE` is case-insensitive, and there is no `ILIKE` return { supportsILike: false, likeCaseSensitive: false }; diff --git a/packages/plugins/policy/src/expression-transformer.ts b/packages/plugins/policy/src/expression-transformer.ts index 2a237c5cb..afa299be2 100644 --- a/packages/plugins/policy/src/expression-transformer.ts +++ b/packages/plugins/policy/src/expression-transformer.ts @@ -908,7 +908,7 @@ export class ExpressionTransformer { const fieldDef = QueryUtils.requireField(this.schema, context.modelOrType, column); if (!fieldDef.originModel || fieldDef.originModel === context.modelOrType) { - return ReferenceNode.create(ColumnNode.create(column), TableNode.create(tableName)); + return this.dialect.fieldRef(context.modelOrType, column, tableName, false).toOperationNode(); } return this.buildDelegateBaseFieldSelect(context.modelOrType, tableName, column, fieldDef.originModel); @@ -936,7 +936,9 @@ export class ExpressionTransformer { kind: 'SelectQueryNode', from: FromNode.create([TableNode.create(baseModel)]), selections: [ - SelectionNode.create(ReferenceNode.create(ColumnNode.create(field), TableNode.create(baseModel))), + SelectionNode.create( + this.dialect.fieldRef(baseModel, field, baseModel, false).toOperationNode() + ), ], where: WhereNode.create( conjunction( diff --git a/packages/plugins/policy/src/policy-handler.ts b/packages/plugins/policy/src/policy-handler.ts index 968703957..c52732ab4 100644 --- a/packages/plugins/policy/src/policy-handler.ts +++ b/packages/plugins/policy/src/policy-handler.ts @@ -311,7 +311,18 @@ export class PolicyHandler extends OperationNodeTransf () => new ExpressionWrapper(beforeUpdateTable!).as('$before'), (join) => { const idFields = QueryUtils.requireIdFields(this.client.$schema, model); - return idFields.reduce((acc, f) => acc.onRef(`${model}.${f}`, '=', `$before.${f}`), join); + const eb = expressionBuilder(); + return idFields.reduce( + (acc, f) => + acc.on(() => + eb( + this.dialect.fieldRef(model, f, model, false), + '=', + this.dialect.fieldRef(model, f, '$before', false), + ), + ), + join, + ); }, ), ); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index c5ac83d9e..a28d48448 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1151,6 +1151,9 @@ importers: decimal.js: specifier: 'catalog:' version: 10.6.0 + uuid: + specifier: ^11.0.5 + version: 11.0.5 devDependencies: '@types/node': specifier: 'catalog:' @@ -7360,6 +7363,7 @@ packages: prebuild-install@7.1.3: resolution: {integrity: sha512-8Mf2cbV7x1cXPUILADGI3wuhfqWvtiLA1iclTDbFRZkgRQS0NqsPZphna9V+HyTEadheuPmjaJMsbzKQFOzLug==} engines: {node: '>=10'} + deprecated: No longer maintained. Please contact the author of the relevant native addon; alternatives are available. hasBin: true prelude-ls@1.2.1: @@ -13234,7 +13238,7 @@ snapshots: eslint: 9.29.0(jiti@2.6.1) eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-jsx-a11y: 6.10.2(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react: 7.37.5(eslint@9.29.0(jiti@2.6.1)) eslint-plugin-react-hooks: 7.0.1(eslint@9.29.0(jiti@2.6.1)) @@ -13267,7 +13271,7 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)) + eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)) transitivePeerDependencies: - supports-color @@ -13282,7 +13286,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1(eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)))(eslint@9.29.0(jiti@2.6.1)): + eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.46.2(eslint@9.29.0(jiti@2.6.1))(typescript@5.9.3))(eslint-import-resolver-typescript@3.10.1)(eslint@9.29.0(jiti@2.6.1)): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 diff --git a/tests/regression/package.json b/tests/regression/package.json index cdd511a7e..707c54afd 100644 --- a/tests/regression/package.json +++ b/tests/regression/package.json @@ -13,17 +13,18 @@ }, "dependencies": { "@zenstackhq/testtools": "workspace:*", - "decimal.js": "catalog:" + "decimal.js": "catalog:", + "uuid": "^11.0.5" }, "devDependencies": { + "@types/node": "catalog:", "@zenstackhq/cli": "workspace:*", "@zenstackhq/language": "workspace:*", - "@zenstackhq/schema": "workspace:*", "@zenstackhq/orm": "workspace:*", - "@zenstackhq/sdk": "workspace:*", "@zenstackhq/plugin-policy": "workspace:*", + "@zenstackhq/schema": "workspace:*", + "@zenstackhq/sdk": "workspace:*", "@zenstackhq/typescript-config": "workspace:*", - "@zenstackhq/vitest-config": "workspace:*", - "@types/node": "catalog:" + "@zenstackhq/vitest-config": "workspace:*" } } diff --git a/tests/regression/test/issue-2394.test.ts b/tests/regression/test/issue-2394.test.ts new file mode 100644 index 000000000..0b4b5df22 --- /dev/null +++ b/tests/regression/test/issue-2394.test.ts @@ -0,0 +1,44 @@ +import { createPolicyTestClient } from '@zenstackhq/testtools'; +import { v4 as uuid } from 'uuid'; +import { describe, expect, it } from 'vitest'; + +describe('Regression for issue #2394', () => { + const UUID_SCHEMA = ` +model Foo { + id String @id @db.Uuid @default(dbgenerated("gen_random_uuid()")) + x String + + @@allow('all', id == x) +} +`; + + it('works with policies', async () => { + const db = await createPolicyTestClient(UUID_SCHEMA, { + provider: 'postgresql', + usePrismaPush: true, + }); + + await db.$unuseAll().foo.create({ data: { x: uuid() } }); + await expect(db.foo.findMany()).toResolveTruthy(); + }); + + it('works with post-update policies', async () => { + const db = await createPolicyTestClient( + ` +model ExchangeRequest { + id String @id @default(dbgenerated("gen_random_uuid()")) @db.Uuid + status String + + @@allow('all', true) + @@deny('post-update', before().status == status) // triggers buildValuesTableSelect +} +`, + { provider: 'postgresql', usePrismaPush: true }, + ); + + const request = await db.exchangeRequest.create({ data: { status: 'pending' } }); + await expect( + db.exchangeRequest.update({ where: { id: request.id }, data: { status: 'done' } }), + ).toResolveTruthy(); + }); +});