Skip to content

Commit b3c7b64

Browse files
author
Christopher Doris
committed
Merge branch 'main' into pandas
2 parents 8841614 + 255f230 commit b3c7b64

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

src/convert.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,7 @@ function pyconvert_typename(t::Py)
155155
return "$m:$n"
156156
end
157157

158-
function pyconvert_get_rules(type::Type, pytype::Py)
159-
@nospecialize type
160-
158+
function _pyconvert_get_rules(pytype::Py)
161159
pyisin(x, ys) = any(pyis(x, y) for y in ys)
162160

163161
# get the MROs of all base types we are considering
@@ -253,7 +251,25 @@ function pyconvert_get_rules(type::Type, pytype::Py)
253251
order = sort(axes(rules, 1), by = i -> (rules[i].priority, -i), rev = true)
254252
rules = rules[order]
255253

256-
# TODO: everything up to here does not depend on the julia type and could be cached
254+
@debug "pyconvert" pytype mro=join(mro, " ")
255+
return rules
256+
end
257+
258+
const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()
259+
260+
pyconvert_preferred_type(pytype::Py) = get!(PYCONVERT_PREFERRED_TYPE, pytype) do
261+
if pyissubclass(pytype, pybuiltins.int)
262+
Union{Int,BigInt}
263+
else
264+
_pyconvert_get_rules(pytype)[1].type
265+
end
266+
end
267+
268+
function pyconvert_get_rules(type::Type, pytype::Py)
269+
@nospecialize type
270+
271+
# this could be cached
272+
rules = _pyconvert_get_rules(pytype)
257273

258274
# intersect rules with type
259275
rules = PyConvertRule[PyConvertRule(typeintersect(rule.type, type), rule.func, rule.priority) for rule in rules]
@@ -267,7 +283,7 @@ function pyconvert_get_rules(type::Type, pytype::Py)
267283
# filter out repeated rules
268284
rules = [rule for (i, rule) in enumerate(rules) if !any((rule.func === rules[j].func) && ((rule.type) <: (rules[j].type)) for j in 1:(i-1))]
269285

270-
@debug "pyconvert" type pytype mro=join(mro, " ") rules
286+
@debug "pyconvert" type rules
271287
return Function[pyconvert_fix(rule.type, rule.func) for rule in rules]
272288
end
273289

src/pywrap/PyArray.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function pyarray_make(::Type{A}, x::Py, info::PyArraySource, ::Type{PyArray{T0,N
124124
if R0 == R1
125125
R = R1
126126
R == R′ || error("incorrect R, got $R, should be $R′")
127-
elseif T0 == T1 && T1 in (Bool, Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128, Float16, Float32, Float64, Complex{Float16}, Complex{Float32}, Complex{Float64})
127+
elseif T0 == T1 && T1 in (Bool, Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128, Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64)
128128
R = T1
129129
R == R′ || error("incorrect R, got $R, should be $R′")
130130
R <: R1 || error("R out of bounds, got $R, should be <: $R1")
@@ -431,19 +431,20 @@ pyarray_offset(x::PyArray{T,N}, i::Vararg{Int,N}) where {T,N} = sum((i .- 1) .*
431431
pyarray_offset(x::PyArray{T,0}) where {T} = 0
432432

433433
pyarray_load(::Type{R}, p::Ptr{R}) where {R} = unsafe_load(p)
434-
pyarray_load(::Type{Py}, p::Ptr{UnsafePyObject}) = begin
435-
o = unsafe_load(p)
436-
o.ptr == C_NULL ? Py(nothing) : pynew(incref(o.ptr))
434+
pyarray_load(::Type{T}, p::Ptr{UnsafePyObject}) where {T} = begin
435+
u = unsafe_load(p)
436+
o = u.ptr == C_NULL ? Py(nothing) : pynew(incref(u.ptr))
437+
T == Py ? o : pyconvert_and_del(T, o)
437438
end
438439

439440
pyarray_store!(p::Ptr{R}, x::R) where {R} = unsafe_store!(p, x)
440-
pyarray_store!(p::Ptr{UnsafePyObject}, x::Py) = begin
441+
pyarray_store!(p::Ptr{UnsafePyObject}, x) = @autopy x begin
441442
decref(unsafe_load(p).ptr)
442-
unsafe_store!(p, UnsafePyObject(GC.@preserve x incref(getptr(x))))
443+
unsafe_store!(p, UnsafePyObject(incref(getptr(x_))))
443444
end
444445

445446
pyarray_get_T(::Type{R}, ::Type{T0}, ::Type{T1}) where {R,T0,T1} = T0 <: R <: T1 ? R : error("not possible")
446-
pyarray_get_T(::Type{UnsafePyObject}, ::Type{T0}, ::Type{T1}) where {T0,T1} = T0 <: Py <: T1 ? Py : T0 <: UnsafePyObject <: T1 ? UnsafePyObject : error("not possible")
447+
pyarray_get_T(::Type{UnsafePyObject}, ::Type{T0}, ::Type{T1}) where {T0,T1} = T0 <: Py <: T1 ? Py : T1
447448

448449
pyarray_check_T(::Type{T}, ::Type{R}) where {T,R} = T == R ? nothing : error("invalid eltype T=$T for raw eltype R=$R")
449-
pyarray_check_T(::Type{Py}, ::Type{UnsafePyObject}) = nothing
450+
pyarray_check_T(::Type{T}, ::Type{UnsafePyObject}) where {T} = nothing

0 commit comments

Comments
 (0)