Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions compiler/fory_compiler/generators/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]:
PrimitiveKind.UINT64: "long",
PrimitiveKind.VAR_UINT64: "long",
PrimitiveKind.TAGGED_UINT64: "long",
PrimitiveKind.FLOAT16: "float",
PrimitiveKind.FLOAT16: "Float16",
PrimitiveKind.FLOAT32: "float",
PrimitiveKind.FLOAT64: "double",
PrimitiveKind.STRING: "String",
Expand Down Expand Up @@ -255,7 +255,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]:
PrimitiveKind.UINT64: "Long",
PrimitiveKind.VAR_UINT64: "Long",
PrimitiveKind.TAGGED_UINT64: "Long",
PrimitiveKind.FLOAT16: "Float",
PrimitiveKind.FLOAT16: "Float16",
PrimitiveKind.FLOAT32: "Float",
PrimitiveKind.FLOAT64: "Double",
PrimitiveKind.ANY: "Object",
Expand All @@ -278,7 +278,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]:
PrimitiveKind.UINT64: "long[]",
PrimitiveKind.VAR_UINT64: "long[]",
PrimitiveKind.TAGGED_UINT64: "long[]",
PrimitiveKind.FLOAT16: "float[]",
PrimitiveKind.FLOAT16: "Float16[]",
PrimitiveKind.FLOAT32: "float[]",
PrimitiveKind.FLOAT64: "double[]",
}
Expand Down Expand Up @@ -1165,6 +1165,8 @@ def collect_type_imports(
imports.add("java.time.LocalDate")
elif field_type.kind == PrimitiveKind.TIMESTAMP:
imports.add("java.time.Instant")
elif field_type.kind == PrimitiveKind.FLOAT16:
imports.add("org.apache.fory.type.Float16")

elif isinstance(field_type, ListType):
# Primitive arrays don't need List import
Expand Down Expand Up @@ -1379,7 +1381,11 @@ def generate_equals_method(self, message: Message) -> List[str]:
)
elif isinstance(field.field_type, PrimitiveType):
kind = field.field_type.kind
if kind in (PrimitiveKind.FLOAT32,):
if kind in (PrimitiveKind.FLOAT16,):
comparisons.append(
f"{field_name}.equalsValue(that.{field_name})"
)
elif kind in (PrimitiveKind.FLOAT32,):
comparisons.append(
f"Float.compare({field_name}, that.{field_name}) == 0"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
import org.apache.fory.codegen.Expression.Literal;
import org.apache.fory.codegen.Expression.Reference;
import org.apache.fory.codegen.Expression.Return;
import org.apache.fory.codegen.Expression.StaticInvoke;
import org.apache.fory.codegen.Expression.Variable;
import org.apache.fory.codegen.Expression.While;
import org.apache.fory.codegen.ExpressionUtils;
Expand Down Expand Up @@ -137,6 +138,7 @@
import org.apache.fory.serializer.collection.MapLikeSerializer;
import org.apache.fory.type.Descriptor;
import org.apache.fory.type.DispatchId;
import org.apache.fory.type.Float16;
import org.apache.fory.type.GenericType;
import org.apache.fory.type.TypeUtils;
import org.apache.fory.type.Types;
Expand Down Expand Up @@ -456,7 +458,7 @@ private Expression serializeForNotNullForField(
Expression inputObject, Expression buffer, Descriptor descriptor, Expression serializer) {
TypeRef<?> typeRef = descriptor.getTypeRef();
Class<?> clz = getRawType(typeRef);
if (isPrimitive(clz) || isBoxed(clz)) {
if (isPrimitiveLikeDescriptor(descriptor, clz)) {
return serializePrimitiveField(inputObject, buffer, descriptor);
} else {
if (clz == String.class) {
Expand Down Expand Up @@ -510,6 +512,8 @@ private Expression serializePrimitiveField(
return new Invoke(buffer, "writeFloat32", inputObject);
case DispatchId.FLOAT64:
return new Invoke(buffer, "writeFloat64", inputObject);
case DispatchId.FLOAT16:
return new Invoke(buffer, "writeInt16", new Invoke(inputObject, "toBits", SHORT_TYPE));
default:
throw new IllegalStateException("Unsupported dispatchId: " + dispatchId);
}
Expand Down Expand Up @@ -667,9 +671,19 @@ protected boolean useMapSerialization(Class<?> type) {
}

protected int getNumericDescriptorDispatchId(Descriptor descriptor) {
int dispatchId =
descriptorDispatchId.computeIfAbsent(descriptor, d -> DispatchId.getDispatchId(fory, d));
Class<?> rawType = descriptor.getRawType();
Preconditions.checkArgument(TypeUtils.unwrap(rawType).isPrimitive());
return descriptorDispatchId.computeIfAbsent(descriptor, d -> DispatchId.getDispatchId(fory, d));
Preconditions.checkArgument(
TypeUtils.unwrap(rawType).isPrimitive() || dispatchId == DispatchId.FLOAT16);
return dispatchId;
}

private boolean isPrimitiveLikeDescriptor(Descriptor descriptor, Class<?> rawType) {
if (isPrimitive(rawType) || isBoxed(rawType)) {
return true;
}
return rawType == Float16.class;
}

/**
Expand Down Expand Up @@ -1994,7 +2008,7 @@ private Expression deserializeForNotNullForField(
Expression buffer, Descriptor descriptor, Expression serializer) {
TypeRef<?> typeRef = descriptor.getTypeRef();
Class<?> cls = getRawType(typeRef);
if (isPrimitive(cls) || isBoxed(cls)) {
if (isPrimitiveLikeDescriptor(descriptor, cls)) {
return deserializePrimitiveField(buffer, descriptor);
} else {
if (cls == String.class) {
Expand Down Expand Up @@ -2067,6 +2081,9 @@ private Expression deserializePrimitiveField(Expression buffer, Descriptor descr
return isPrimitive
? readFloat64(buffer)
: new Invoke(buffer, readFloat64Func(), DOUBLE_TYPE);
case DispatchId.FLOAT16:
return new StaticInvoke(
Float16.class, "fromBits", TypeRef.of(Float16.class), readInt16(buffer));
default:
throw new IllegalStateException("Unsupported dispatchId: " + dispatchId);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.fory.collection;

import java.util.AbstractList;
import java.util.Arrays;
import java.util.Objects;
import java.util.RandomAccess;
import org.apache.fory.type.Float16;

public final class Float16List extends AbstractList<Float16> implements RandomAccess {
private static final int DEFAULT_CAPACITY = 10;

private short[] array;
private int size;

public Float16List() {
this(DEFAULT_CAPACITY);
}

public Float16List(int initialCapacity) {
if (initialCapacity < 0) {
throw new IllegalArgumentException("Illegal capacity: " + initialCapacity);
}
this.array = new short[initialCapacity];
this.size = 0;
}

public Float16List(short[] array) {
this.array = array;
this.size = array.length;
}

@Override
public Float16 get(int index) {
checkIndex(index);
return Float16.fromBits(array[index]);
}

@Override
public int size() {
return size;
}

@Override
public Float16 set(int index, Float16 element) {
checkIndex(index);
Objects.requireNonNull(element, "element");
short prev = array[index];
array[index] = element.toBits();
return Float16.fromBits(prev);
}

public void set(int index, short bits) {
checkIndex(index);
array[index] = bits;
}

public void set(int index, float value) {
checkIndex(index);
array[index] = Float16.valueOf(value).toBits();
}

@Override
public void add(int index, Float16 element) {
checkPositionIndex(index);
ensureCapacity(size + 1);
System.arraycopy(array, index, array, index + 1, size - index);
array[index] = element.toBits();
size++;
modCount++;
}

@Override
public boolean add(Float16 element) {
Objects.requireNonNull(element, "element");
ensureCapacity(size + 1);
array[size++] = element.toBits();
modCount++;
return true;
}

public boolean add(short bits) {
ensureCapacity(size + 1);
array[size++] = bits;
modCount++;
return true;
}

public boolean add(float value) {
ensureCapacity(size + 1);
array[size++] = Float16.valueOf(value).toBits();
modCount++;
return true;
}

public float getFloat(int index) {
checkIndex(index);
return Float16.fromBits(array[index]).toFloat();
}

public short getShort(int index) {
checkIndex(index);
return array[index];
}

public boolean hasArray() {
return array != null;
}

public short[] getArray() {
return array;
}

public short[] copyArray() {
return Arrays.copyOf(array, size);
}

private void ensureCapacity(int minCapacity) {
if (array.length >= minCapacity) {
return;
}
int newCapacity = array.length + (array.length >> 1) + 1;
if (newCapacity < minCapacity) {
newCapacity = minCapacity;
}
array = Arrays.copyOf(array, newCapacity);
}

private void checkIndex(int index) {
if (index < 0 || index >= size) {
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
}

private void checkPositionIndex(int index) {
if (index < 0 || index > size) {
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.fory.type.Descriptor;
import org.apache.fory.type.DescriptorGrouper;
import org.apache.fory.type.DispatchId;
import org.apache.fory.type.Float16;
import org.apache.fory.type.Generics;
import org.apache.fory.type.unsigned.Uint16;
import org.apache.fory.type.unsigned.Uint32;
Expand Down Expand Up @@ -368,6 +369,9 @@ static void writeNotPrimitiveFieldValue(
case DispatchId.FLOAT64:
buffer.writeFloat64((Double) fieldValue);
return;
case DispatchId.FLOAT16:
buffer.writeInt16(((Float16) fieldValue).toBits());
return;
default:
binding.writeField(fieldInfo, RefMode.NONE, buffer, fieldValue);
}
Expand Down Expand Up @@ -570,6 +574,8 @@ private static Object readNotNullBuildInFieldValue(
return buffer.readFloat32();
case DispatchId.FLOAT64:
return buffer.readFloat64();
case DispatchId.FLOAT16:
return Float16.fromBits(buffer.readInt16());
case DispatchId.STRING:
return binding.fory.readJavaString(buffer);
default:
Expand Down Expand Up @@ -788,6 +794,9 @@ private static void readNotPrimitiveFieldValue(
case DispatchId.FLOAT64:
fieldAccessor.putObject(targetObject, buffer.readFloat64());
return;
case DispatchId.FLOAT16:
fieldAccessor.putObject(targetObject, Float16.fromBits(buffer.readInt16()));
return;
case DispatchId.STRING:
fieldAccessor.putObject(targetObject, binding.fory.readJavaString(buffer));
return;
Expand Down Expand Up @@ -944,6 +953,7 @@ private static void copySetNotPrimitiveField(
case DispatchId.TAGGED_UINT64:
case DispatchId.FLOAT32:
case DispatchId.FLOAT64:
case DispatchId.FLOAT16:
case DispatchId.STRING:
Platform.putObject(newObj, fieldOffset, Platform.getObject(originObj, fieldOffset));
break;
Expand Down Expand Up @@ -1006,6 +1016,7 @@ private Object copyNotPrimitiveField(Object targetObject, long fieldOffset, int
case DispatchId.VAR_UINT64:
case DispatchId.TAGGED_UINT64:
case DispatchId.FLOAT64:
case DispatchId.FLOAT16:
case DispatchId.STRING:
return Platform.getObject(targetObject, fieldOffset);
default:
Expand Down
Loading
Loading