Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 2 additions & 24 deletions conformance/third_party/conformance.exp
Original file line number Diff line number Diff line change
Expand Up @@ -3633,36 +3633,14 @@
{
"code": -2,
"column": 12,
"concise_description": "revealed type: (x: int) -> Class7[int]",
"description": "revealed type: (x: int) -> Class7[int]",
"concise_description": "revealed type: Overload[\n (x: int) -> Class7[int]\n (x: str) -> Class7[str]\n]",
"description": "revealed type: Overload[\n (x: int) -> Class7[int]\n (x: str) -> Class7[str]\n]",
"line": 163,
"name": "reveal-type",
"severity": "info",
"stop_column": 2,
"stop_line": 165
},
{
"code": -2,
"column": 12,
"concise_description": "assert_type(Class7[int], Class7[str]) failed",
"description": "assert_type(Class7[int], Class7[str]) failed",
"line": 167,
"name": "assert-type",
"severity": "error",
"stop_column": 33,
"stop_line": 167
},
{
"code": -2,
"column": 16,
"concise_description": "Argument `Literal['']` is not assignable to parameter `x` with type `int`",
"description": "Argument `Literal['']` is not assignable to parameter `x` with type `int`",
"line": 167,
"name": "bad-argument-type",
"severity": "error",
"stop_column": 18,
"stop_line": 167
},
{
"code": -2,
"column": 12,
Expand Down
1 change: 0 additions & 1 deletion conformance/third_party/conformance.result
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"constructors_callable.py": [
"Line 186: Expected 1 errors",
"Line 197: Expected 1 errors",
"Line 167: Unexpected errors ['assert_type(Class7[int], Class7[str]) failed', \"Argument `Literal['']` is not assignable to parameter `x` with type `int`\"]",
"Line 185: Unexpected errors ['assert_type(Class8[Any], Class8[str]) failed']"
],
"constructors_consistency.py": [],
Expand Down
4 changes: 2 additions & 2 deletions conformance/third_party/results.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"pass": 121,
"fail": 17,
"pass_rate": 0.88,
"differences": 52,
"differences": 51,
"passing": [
"aliases_explicit.py",
"aliases_newtype.py",
Expand Down Expand Up @@ -131,7 +131,7 @@
"aliases_implicit.py": 5,
"annotations_forward_refs.py": 2,
"callables_annotation.py": 2,
"constructors_callable.py": 4,
"constructors_callable.py": 3,
"dataclasses_descriptors.py": 6,
"dataclasses_slots.py": 2,
"exceptions_context_managers.py": 2,
Expand Down
252 changes: 238 additions & 14 deletions pyrefly/lib/alt/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@ use pyrefly_types::literal::Literal;
use pyrefly_types::quantified::Quantified;
use pyrefly_types::typed_dict::TypedDictInner;
use pyrefly_types::types::CalleeKind;
use pyrefly_types::types::Overload;
use pyrefly_types::types::TArgs;
use pyrefly_types::types::TParams;
use pyrefly_types::types::Union;
use pyrefly_util::owner::Owner;
use pyrefly_util::prelude::SliceExt;
use pyrefly_util::prelude::VecExt;
use ruff_python_ast::Arguments;
Expand All @@ -34,6 +36,7 @@ use crate::alt::callable::CallArg;
use crate::alt::callable::CallKeyword;
use crate::alt::callable::CallWithTypes;
use crate::alt::class::class_field::DescriptorBase;
use crate::alt::expr::TypeOrExpr;
use crate::alt::unwrap::HintRef;
use crate::binding::binding::Key;
use crate::config::error_kind::ErrorKind;
Expand Down Expand Up @@ -848,6 +851,168 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

/// When a generic function is called with an overloaded function argument (positional or
/// keyword), apply the call once per overload signature and return an overloaded result.
/// This preserves the overloaded type through the generic function rather than collapsing
/// to one overload.
///
/// Handles two cases:
/// - An already-overloaded function (`Type::Overload`): e.g. `copy(foo)` where `foo` is overloaded.
/// - A class with an overloaded constructor (`Type::ClassDef`): e.g. `accepts_callable(Class7)`
/// where `Class7.__init__` is overloaded.
///
/// Scans positional args first, then keyword args, for the first overloaded callable.
/// Only the first overloaded argument is expanded; supporting multiple overloaded arguments
/// would require a cartesian product of expansions, which is not currently implemented.
///
/// Returns `Some(Type::Overload(...))` if all signatures succeed without errors,
/// `None` if expansion is not applicable or any signature fails.
fn call_generic_with_overloaded_arg(
&self,
callable: Callable,
metadata: &FuncMetadata,
tparams: Option<&TParams>,
args: &[CallArg],
keywords: &[CallKeyword],
arguments_range: TextRange,
errors: &ErrorCollector,
hint: Option<HintRef>,
) -> Option<Type> {
/// Identifies which argument (positional or keyword) contains the overloaded type.
#[derive(Clone, Copy)]
enum OverloadSource {
Arg(usize),
Keyword(usize),
}

// Helper to extract an Overload from a type, handling both already-overloaded
// functions and classes whose constructors are overloaded.
let find_overload_in_type =
|ty: &Type, range: TextRange| -> Option<(Overload, TextRange)> {
match ty {
// Already-overloaded function value.
Type::Overload(overload) => Some((overload.clone(), range)),
// Class whose constructor is overloaded.
Type::ClassDef(cls) => {
let class_type = match self.instantiate(cls) {
Type::ClassType(ct) => ct,
_ => return None,
};
match self.constructor_to_callable(&class_type) {
Type::Overload(overload) => Some((overload, range)),
_ => None,
}
}
_ => None,
}
};

// Scan positional args first, then keyword args, for the first overloaded callable.
let (source, overload, arg_range) = args
.iter()
.enumerate()
.find_map(|(i, arg)| {
if let CallArg::Arg(TypeOrExpr::Type(ty, range)) = arg {
find_overload_in_type(ty, *range).map(|(o, r)| (OverloadSource::Arg(i), o, r))
} else {
None
}
})
.or_else(|| {
keywords.iter().enumerate().find_map(|(i, kw)| {
if let TypeOrExpr::Type(ty, range) = kw.value {
find_overload_in_type(ty, range)
.map(|(o, r)| (OverloadSource::Keyword(i), o, r))
} else {
None
}
})
})?;

// For each overload signature, call the generic function with that specific type.
// Store substituted types in an owner to give them a stable address.
let type_owner = Owner::<Type>::new();
let mut results: Vec<OverloadType> = Vec::new();
for sig in &overload.signatures {
let sig_type = sig.as_type();
let (modified_args, modified_kws): (Vec<CallArg<'_>>, Vec<CallKeyword<'_>>) =
match source {
OverloadSource::Arg(idx) => (
args.iter()
.enumerate()
.map(|(i, arg)| {
if i == idx {
CallArg::Arg(TypeOrExpr::Type(
type_owner.push(sig_type.clone()),
arg_range,
))
} else {
arg.clone()
}
})
.collect(),
keywords.to_vec(),
),
OverloadSource::Keyword(idx) => (
args.to_vec(),
keywords
.iter()
.enumerate()
.map(|(i, kw)| {
if i == idx {
CallKeyword {
range: kw.range,
arg: kw.arg,
value: TypeOrExpr::Type(
type_owner.push(sig_type.clone()),
arg_range,
),
}
} else {
kw.clone()
}
})
.collect(),
),
};

let call_errors = self.error_collector();
let res = self.callable_infer(
callable.clone(),
Some(&metadata.kind),
tparams,
None,
&modified_args,
&modified_kws,
arguments_range,
errors,
&call_errors,
None,
hint,
None,
);
if !call_errors.is_empty() {
return None;
}
// The result must be a callable type to be included in the overloaded result.
match res {
Type::Callable(c) => results.push(OverloadType::Function(Function {
signature: *c,
metadata: overload.metadata.as_ref().clone(),
})),
Type::Function(f) => results.push(OverloadType::Function(*f)),
_ => return None,
}
}

Vec1::try_from_vec(results).ok().map(|signatures| {
Type::Overload(Overload {
signatures,
metadata: overload.metadata.clone(),
})
})
}

fn call_infer_with_callee_range(
&self,
call_target: CallTarget,
Expand Down Expand Up @@ -1011,20 +1176,79 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
signature: callable,
metadata,
},
)) => self.callable_infer(
callable,
Some(&metadata.kind),
tparams.as_deref(),
None,
args,
keywords,
arguments_range,
errors,
errors,
context,
hint,
ctor_targs,
),
)) => {
// For generic functions, check if any argument is an overloaded function.
// If so, apply the call once per overload signature and return an overloaded
// result, preserving the overloaded type through the generic function.
//
// We pre-evaluate all args to types here so that:
// (1) call_generic_with_overloaded_arg can scan for TypeOrExpr::Type variants
// (expression args have not been evaluated yet at this point), and
// (2) if no overload expansion applies, the already-typed args are reused in
// the fallback callable_infer without re-evaluating, which would emit
// duplicate errors.
//
// Only do this for "polymorphic callables" — generic functions that actually
// accept a Callable-typed parameter. This avoids unnecessary pre-evaluation
// overhead for generic functions like `identity[T](x: T)` that don't take
// callable arguments, and preserves contextual typing for lambdas passed to
// such functions.
let has_callable_param = match &callable.params {
Params::ParamSpec(..) => true,
Params::List(list) => list
.items()
.iter()
.any(|p| matches!(p.as_type(), Type::Callable(_))),
_ => false,
};
if tparams.is_some() && has_callable_param {
let call = CallWithTypes::new();
let typed_args = call.vec_call_arg(args, self, errors);
let typed_kws = call.vec_call_keyword(keywords, self, errors);
if let Some(result) = self.call_generic_with_overloaded_arg(
callable.clone(),
&metadata,
tparams.as_deref(),
&typed_args,
&typed_kws,
arguments_range,
errors,
hint,
) {
result
} else {
self.callable_infer(
callable,
Some(&metadata.kind),
tparams.as_deref(),
None,
&typed_args,
&typed_kws,
arguments_range,
errors,
errors,
context,
hint,
ctor_targs,
)
}
} else {
self.callable_infer(
callable,
Some(&metadata.kind),
tparams.as_deref(),
None,
args,
keywords,
arguments_range,
errors,
errors,
context,
hint,
ctor_targs,
)
}
}
CallTarget::FunctionOverload(overloads, metadata) => {
self.call_overloads(
overloads,
Expand Down
3 changes: 1 addition & 2 deletions pyrefly/lib/test/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1342,8 +1342,7 @@ class Class7(Generic[T]):
pass

r7 = accepts_callable(Class7)
# pyrefly incorrectly errors on these - should be OK
assert_type(r7(""), Class7[str]) # E: assert_type(Class7[int], Class7[str]) failed # E: Argument `Literal['']` is not assignable
assert_type(r7(""), Class7[str])

class Class8(Generic[T]):
def __new__(cls, x: list[T], y: list[T]) -> Self:
Expand Down
2 changes: 1 addition & 1 deletion pyrefly/lib/test/contextual.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def g(x: int, y: str):
x1 = f(g, lambda x, y: None)
reveal_type(x1) # E: revealed type: (x: int, y: str) -> None

x2 = f(g, lambda x, z: None) # E: Argument `(x: int, z: Unknown) -> None` is not assignable to parameter `g` with type `(x: int, y: str) -> None`
x2 = f(g, lambda x, z: None) # E: Argument `(x: Unknown, z: Unknown) -> None` is not assignable to parameter `g` with type `(x: int, y: str) -> None`
reveal_type(x2) # E: revealed type: (x: int, y: str) -> None
"#,
);
Expand Down
Loading