diff --git a/vector/src/main/codegen/templates/AbstractFieldReader.java b/vector/src/main/codegen/templates/AbstractFieldReader.java index 25b071fab7..7e84323b64 100644 --- a/vector/src/main/codegen/templates/AbstractFieldReader.java +++ b/vector/src/main/codegen/templates/AbstractFieldReader.java @@ -108,6 +108,23 @@ public void copyAsField(String name, ${name}Writer writer) { } + + public void read(ExtensionHolder holder) { + fail("Extension"); + } + + public void read(int arrayIndex, ExtensionHolder holder) { + fail("RepeatedExtension"); + } + + public void copyAsValue(AbstractExtensionTypeWriter writer) { + fail("CopyAsValueExtension"); + } + + public void copyAsField(String name, AbstractExtensionTypeWriter writer) { + fail("CopyAsFieldExtension"); + } + public FieldReader reader(String name) { fail("reader(String name)"); return null; diff --git a/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java b/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java index 951edd5eee..2e7792fcfe 100644 --- a/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java +++ b/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java @@ -295,7 +295,7 @@ public MapWriter map(boolean keysSorted) { @Override public ExtensionWriter extension(ArrowType arrowType) { - return getWriter(MinorType.EXTENSIONTYPE).extension(arrowType); + return getWriter(MinorType.LIST).extension(arrowType); } @Override @@ -325,7 +325,7 @@ public MapWriter map(String name, boolean keysSorted) { @Override public ExtensionWriter extension(String name, ArrowType arrowType) { - return getWriter(MinorType.EXTENSIONTYPE).extension(name, arrowType); + return getWriter(MinorType.STRUCT).extension(name, arrowType); } <#list vv.types as type><#list type.minor as minor> diff --git a/vector/src/main/codegen/templates/BaseReader.java b/vector/src/main/codegen/templates/BaseReader.java index e75e8a2974..c52345af21 100644 --- a/vector/src/main/codegen/templates/BaseReader.java +++ b/vector/src/main/codegen/templates/BaseReader.java @@ -73,7 +73,7 @@ public interface RepeatedMapReader extends MapReader{ public interface ScalarReader extends <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> ${name}Reader, - BaseReader {} + ExtensionReader, BaseReader {} interface ComplexReader{ StructReader rootAsStruct(); diff --git a/vector/src/main/codegen/templates/NullReader.java b/vector/src/main/codegen/templates/NullReader.java index 1d77248e96..88e6ea98ea 100644 --- a/vector/src/main/codegen/templates/NullReader.java +++ b/vector/src/main/codegen/templates/NullReader.java @@ -86,6 +86,10 @@ public void read(int arrayIndex, Nullable${name}Holder holder){ } + public void read(ExtensionHolder holder) { + holder.isSet = 0; + } + public int size(){ return 0; } diff --git a/vector/src/main/codegen/templates/PromotableWriter.java b/vector/src/main/codegen/templates/PromotableWriter.java index 8d7d57bb9d..d22eb00b2c 100644 --- a/vector/src/main/codegen/templates/PromotableWriter.java +++ b/vector/src/main/codegen/templates/PromotableWriter.java @@ -550,6 +550,10 @@ public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory) { getWriter(MinorType.EXTENSIONTYPE).addExtensionTypeWriterFactory(factory); } + public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory factory, ArrowType arrowType) { + getWriter(MinorType.EXTENSIONTYPE, arrowType).addExtensionTypeWriterFactory(factory); + } + @Override public void allocate() { getWriter().allocate(); diff --git a/vector/src/main/codegen/templates/UnionListWriter.java b/vector/src/main/codegen/templates/UnionListWriter.java index 9424533f29..94723e6c9d 100644 --- a/vector/src/main/codegen/templates/UnionListWriter.java +++ b/vector/src/main/codegen/templates/UnionListWriter.java @@ -53,6 +53,7 @@ public class Union${listName}Writer extends AbstractFieldWriter { private boolean inStruct = false; private boolean listStarted = false; private String structName; + private ArrowType extensionType; <#if listName == "LargeList" || listName == "LargeListView"> private static final long OFFSET_WIDTH = 8; <#else> @@ -203,8 +204,8 @@ public MapWriter map(String name, boolean keysSorted) { @Override public ExtensionWriter extension(ArrowType arrowType) { - writer.extension(arrowType); - return writer; + this.extensionType = arrowType; + return this; } @Override public ExtensionWriter extension(String name, ArrowType arrowType) { @@ -337,13 +338,17 @@ public void writeNull() { @Override public void writeExtension(Object value) { writer.writeExtension(value); + writer.setPosition(writer.idx() + 1); } + @Override public void addExtensionTypeWriterFactory(ExtensionTypeWriterFactory var1) { - writer.addExtensionTypeWriterFactory(var1); + writer.addExtensionTypeWriterFactory(var1, extensionType); } + public void write(ExtensionHolder var1) { writer.write(var1); + writer.setPosition(writer.idx() + 1); } <#list vv.types as type> diff --git a/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java b/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java new file mode 100644 index 0000000000..1ba7b27156 --- /dev/null +++ b/vector/src/main/java/org/apache/arrow/vector/complex/reader/ExtensionReader.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.complex.reader; + +import org.apache.arrow.vector.holders.ExtensionHolder; + +/** Interface for reading extension types. Extends the functionality of {@link BaseReader}. */ +public interface ExtensionReader extends BaseReader { + + /** + * Reads to the given extension holder. + * + * @param holder the {@link ExtensionHolder} to read + */ + void read(ExtensionHolder holder); + + /** + * Reads and returns an object representation of the extension type. + * + * @return the object representation of the extension type + */ + Object readObject(); + + /** + * Checks if the current value is set. + * + * @return true if the value is set, false otherwise + */ + boolean isSet(); +} diff --git a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index 1d6fa39f9e..97a0451188 100644 --- a/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -23,16 +23,22 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.UUID; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; +import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; @@ -41,6 +47,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.UuidType; import org.apache.arrow.vector.util.TransferPair; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; @@ -1198,6 +1205,71 @@ public void testGetTransferPairWithField() { } } + @Test + public void testListVectorWithExtensionType() throws Exception { + final FieldType type = FieldType.nullable(new UuidType()); + try (final ListVector inVector = new ListVector("list", allocator, type, null)) { + UnionListWriter writer = inVector.getWriter(); + writer.allocate(); + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); + ExtensionWriter extensionWriter = writer.extension(new UuidType()); + extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u2); + writer.endList(); + + writer.setValueCount(1); + + FieldReader reader = inVector.getReader(); + assertTrue(reader.isSet(), "shouldn't be null"); + Object result = inVector.getObject(0); + ArrayList resultSet = (ArrayList) result; + assertEquals(2, resultSet.size()); + assertEquals(u1, resultSet.get(0)); + assertEquals(u2, resultSet.get(1)); + } + } + + @Test + public void testListVectorReaderForExtensionType() throws Exception { + final FieldType type = FieldType.nullable(new UuidType()); + try (final ListVector inVector = new ListVector("list", allocator, type, null)) { + UnionListWriter writer = inVector.getWriter(); + writer.allocate(); + writer.setPosition(0); + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + writer.startList(); + ExtensionWriter extensionWriter = writer.extension(new UuidType()); + extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + extensionWriter.writeExtension(u1); + extensionWriter.writeExtension(u2); + writer.endList(); + + writer.setValueCount(1); + + UnionListReader reader = inVector.getReader(); + assertTrue(reader.isSet(), "shouldn't be null"); + reader.setPosition(0); + reader.next(); + FieldReader uuidReader = reader.reader(); + UuidHolder holder = new UuidHolder(); + uuidReader.read(holder); + ByteBuffer bb = ByteBuffer.wrap(holder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u1, actualUuid); + reader.next(); + uuidReader = reader.reader(); + uuidReader.read(holder); + bb = ByteBuffer.wrap(holder.value); + actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u2, actualUuid); + } + } + private void writeIntValues(UnionListWriter writer, int[] values) { writer.startList(); for (int v : values) { diff --git a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java index 5c90d45f60..72ba4aa555 100644 --- a/vector/src/test/java/org/apache/arrow/vector/UuidVector.java +++ b/vector/src/test/java/org/apache/arrow/vector/UuidVector.java @@ -20,6 +20,9 @@ import java.util.UUID; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.util.hash.ArrowBufHasher; +import org.apache.arrow.vector.complex.impl.UuidReaderImpl; +import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.UuidType; @@ -79,11 +82,21 @@ public TransferPair makeTransferPair(ValueVector to) { return new TransferImpl((UuidVector) to); } + @Override + protected FieldReader getReaderImpl() { + return new UuidReaderImpl(this); + } + public void setSafe(int index, byte[] value) { getUnderlyingVector().setIndexDefined(index); getUnderlyingVector().setSafe(index, value); } + public void get(int index, UuidHolder holder) { + holder.value = getUnderlyingVector().get(index); + holder.isSet = 1; + } + public class TransferImpl implements TransferPair { UuidVector to; ValueVector targetUnderlyingVector; diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 1556852c5a..7b8b1f9ef9 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -805,4 +805,29 @@ public void testExtensionType() throws Exception { assertEquals(u2, uuidVector.getObject(1)); } } + + @Test + public void testExtensionTypeForList() throws Exception { + try (final ListVector container = ListVector.empty(EMPTY_SCHEMA_PATH, allocator); + final UuidVector v = + (UuidVector) container.addOrGetVector(FieldType.nullable(new UuidType())).getVector(); + final PromotableWriter writer = new PromotableWriter(v, container)) { + UUID u1 = UUID.randomUUID(); + UUID u2 = UUID.randomUUID(); + container.allocateNew(); + container.setValueCount(1); + writer.addExtensionTypeWriterFactory(new UuidWriterFactory()); + + writer.setPosition(0); + writer.writeExtension(u1); + writer.setPosition(1); + writer.writeExtension(u2); + + container.setValueCount(2); + + UuidVector uuidVector = (UuidVector) container.getDataVector(); + assertEquals(u1, uuidVector.getObject(0)); + assertEquals(u2, uuidVector.getObject(1)); + } + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java new file mode 100644 index 0000000000..16dd734de8 --- /dev/null +++ b/vector/src/test/java/org/apache/arrow/vector/complex/impl/UuidReaderImpl.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.vector.complex.impl; + +import org.apache.arrow.vector.UuidVector; +import org.apache.arrow.vector.holder.UuidHolder; +import org.apache.arrow.vector.holders.ExtensionHolder; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; + +public class UuidReaderImpl extends AbstractFieldReader { + + private final UuidVector vector; + + public UuidReaderImpl(UuidVector vector) { + super(); + this.vector = vector; + } + + @Override + public MinorType getMinorType() { + return vector.getMinorType(); + } + + @Override + public Field getField() { + return vector.getField(); + } + + @Override + public boolean isSet() { + return !vector.isNull(idx()); + } + + @Override + public void read(ExtensionHolder holder) { + vector.get(idx(), (UuidHolder) holder); + } + + @Override + public void read(int arrayIndex, ExtensionHolder holder) { + vector.get(arrayIndex, (UuidHolder) holder); + } + + @Override + public void copyAsValue(AbstractExtensionTypeWriter writer) { + UuidWriterImpl impl = (UuidWriterImpl) writer; + impl.vector.copyFromSafe(idx(), impl.idx(), vector); + } +} diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index 2745386db4..f374eb41e4 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -19,6 +19,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -31,6 +32,7 @@ import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.UUID; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -64,6 +66,7 @@ import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionReader; import org.apache.arrow.vector.complex.impl.UnionWriter; +import org.apache.arrow.vector.complex.impl.UuidWriterFactory; import org.apache.arrow.vector.complex.reader.BaseReader.StructReader; import org.apache.arrow.vector.complex.reader.BigIntReader; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -71,9 +74,11 @@ import org.apache.arrow.vector.complex.reader.Float8Reader; import org.apache.arrow.vector.complex.reader.IntReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter; +import org.apache.arrow.vector.complex.writer.BaseWriter.ExtensionWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.ListWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; +import org.apache.arrow.vector.holder.UuidHolder; import org.apache.arrow.vector.holders.DecimalHolder; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; @@ -84,6 +89,7 @@ import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.ArrowType.Int; @@ -93,6 +99,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType.Utf8; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.UuidType; import org.apache.arrow.vector.util.CallBack; import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.JsonStringArrayList; @@ -2489,4 +2496,38 @@ public void unionWithVarCharAndBinaryHelpers() throws Exception { "row12", new String(vector.getLargeVarBinaryVector().get(11), StandardCharsets.UTF_8)); } } + + @Test + public void extensionWriterReader() throws Exception { + // test values + UUID u1 = UUID.randomUUID(); + + try (NonNullableStructVector parent = NonNullableStructVector.empty("parent", allocator)) { + // write + + ComplexWriter writer = new ComplexWriterImpl("root", parent); + StructWriter rootWriter = writer.rootAsStruct(); + + { + ExtensionWriter extensionWriter = rootWriter.extension("uuid1", new UuidType()); + extensionWriter.setPosition(0); + extensionWriter.addExtensionTypeWriterFactory(new UuidWriterFactory()); + extensionWriter.writeExtension(u1); + } + // read + StructReader rootReader = new SingleStructReaderImpl(parent).reader("root"); + { + FieldReader uuidReader = rootReader.reader("uuid1"); + uuidReader.setPosition(0); + UuidHolder uuidHolder = new UuidHolder(); + uuidReader.read(uuidHolder); + final ByteBuffer bb = ByteBuffer.wrap(uuidHolder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(u1, actualUuid); + assertTrue(uuidReader.isSet()); + assertEquals(uuidReader.getMinorType(), MinorType.EXTENSIONTYPE); + assertInstanceOf(UuidType.class, uuidReader.getField().getFieldType().getType()); + } + } + } } diff --git a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java index bf1b9b0dfa..269cff0670 100644 --- a/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java +++ b/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestSimpleWriter.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.impl.LargeVarBinaryWriterImpl; import org.apache.arrow.vector.complex.impl.LargeVarCharWriterImpl; +import org.apache.arrow.vector.complex.impl.UuidReaderImpl; import org.apache.arrow.vector.complex.impl.UuidWriterImpl; import org.apache.arrow.vector.complex.impl.VarBinaryWriterImpl; import org.apache.arrow.vector.complex.impl.VarCharWriterImpl; @@ -204,4 +205,23 @@ public void testWriteToExtensionVector() throws Exception { assertEquals(uuid, result); } } + + @Test + public void testReaderCopyAsValueExtensionVector() throws Exception { + try (UuidVector vector = new UuidVector("test", allocator); + UuidVector vectorForRead = new UuidVector("test2", allocator); + UuidWriterImpl writer = new UuidWriterImpl(vector)) { + UUID uuid = UUID.randomUUID(); + vectorForRead.setValueCount(1); + vectorForRead.set(0, uuid); + UuidReaderImpl reader = (UuidReaderImpl) vectorForRead.getReader(); + reader.copyAsValue(writer); + UuidReaderImpl reader2 = (UuidReaderImpl) vector.getReader(); + UuidHolder holder = new UuidHolder(); + reader2.read(0, holder); + final ByteBuffer bb = ByteBuffer.wrap(holder.value); + UUID actualUuid = new UUID(bb.getLong(), bb.getLong()); + assertEquals(uuid, actualUuid); + } + } }