@@ -27,7 +27,9 @@ unsafe extern "C" {
2727mod decl {
2828 use crate :: {
2929 AsObject , Py , PyObjectRef , PyResult , VirtualMachine ,
30- builtins:: { PyStrRef , PyTypeRef , PyUtf8StrRef } ,
30+ builtins:: { PyStrRef , PyTypeRef } ,
31+ common:: wtf8:: Wtf8Buf ,
32+ convert:: ToPyObject ,
3133 function:: { Either , FuncArgs , OptionalArg } ,
3234 types:: { PyStructSequence , struct_sequence_new} ,
3335 } ;
@@ -85,13 +87,19 @@ mod decl {
8587
8688 #[ pyfunction]
8789 fn sleep ( seconds : PyObjectRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
90+ let seconds_type_name = seconds. clone ( ) . class ( ) . name ( ) . to_owned ( ) ;
8891 let dur = seconds. try_into_value :: < Duration > ( vm) . map_err ( |e| {
8992 if e. class ( ) . is ( vm. ctx . exceptions . value_error )
9093 && let Some ( s) = e. args ( ) . first ( ) . and_then ( |arg| arg. str ( vm) . ok ( ) )
9194 && s. as_str ( ) == "negative duration"
9295 {
9396 return vm. new_value_error ( "sleep length must be non-negative" ) ;
9497 }
98+ if e. class ( ) . is ( vm. ctx . exceptions . type_error ) {
99+ return vm. new_type_error ( format ! (
100+ "'{seconds_type_name}' object cannot be interpreted as an integer or float"
101+ ) ) ;
102+ }
95103 e
96104 } ) ?;
97105
@@ -575,17 +583,9 @@ mod decl {
575583 }
576584
577585 #[ pyfunction]
578- fn strftime (
579- format : PyUtf8StrRef ,
580- t : OptionalArg < StructTimeData > ,
581- vm : & VirtualMachine ,
582- ) -> PyResult {
586+ fn strftime ( format : PyStrRef , t : OptionalArg < StructTimeData > , vm : & VirtualMachine ) -> PyResult {
583587 #[ cfg( unix) ]
584588 {
585- if format. as_str ( ) . contains ( '\0' ) {
586- return Err ( vm. new_value_error ( "embedded null character" ) ) ;
587- }
588-
589589 let checked_tm = match t {
590590 OptionalArg :: Present ( value) => checked_tm_from_struct_time ( & value, vm, "strftime" ) ?,
591591 OptionalArg :: Missing => {
@@ -598,27 +598,60 @@ mod decl {
598598 let mut tm = checked_tm. tm ;
599599 tm. tm_isdst = tm. tm_isdst . clamp ( -1 , 1 ) ;
600600
601- let fmt = CString :: new ( format. as_str ( ) )
602- . map_err ( |_| vm. new_value_error ( "embedded null character" ) ) ?;
603- let mut size = 1024usize ;
604- let max_scale = 256usize . saturating_mul ( format. as_str ( ) . len ( ) . max ( 1 ) ) ;
605-
606- loop {
607- let mut out = vec ! [ 0u8 ; size] ;
608- let written = unsafe {
609- libc:: strftime (
610- out. as_mut_ptr ( ) . cast ( ) ,
611- out. len ( ) ,
612- fmt. as_ptr ( ) ,
613- & tm as * const libc:: tm ,
614- )
615- } ;
616- if written > 0 || size >= max_scale {
617- let s = String :: from_utf8_lossy ( & out[ ..written] ) . into_owned ( ) ;
618- return Ok ( vm. ctx . new_str ( s) . into ( ) ) ;
601+ fn strftime_ascii ( fmt : & str , tm : & libc:: tm , vm : & VirtualMachine ) -> PyResult < String > {
602+ let fmt_c =
603+ CString :: new ( fmt) . map_err ( |_| vm. new_value_error ( "embedded null character" ) ) ?;
604+ let mut size = 1024usize ;
605+ let max_scale = 256usize . saturating_mul ( fmt. len ( ) . max ( 1 ) ) ;
606+ loop {
607+ let mut out = vec ! [ 0u8 ; size] ;
608+ let written = unsafe {
609+ libc:: strftime (
610+ out. as_mut_ptr ( ) . cast ( ) ,
611+ out. len ( ) ,
612+ fmt_c. as_ptr ( ) ,
613+ tm as * const libc:: tm ,
614+ )
615+ } ;
616+ if written > 0 || size >= max_scale {
617+ return Ok ( String :: from_utf8_lossy ( & out[ ..written] ) . into_owned ( ) ) ;
618+ }
619+ size = size. saturating_mul ( 2 ) ;
620+ }
621+ }
622+
623+ let mut out = Wtf8Buf :: new ( ) ;
624+ let mut ascii = String :: new ( ) ;
625+
626+ for codepoint in format. as_wtf8 ( ) . code_points ( ) {
627+ if codepoint. to_u32 ( ) == 0 {
628+ if !ascii. is_empty ( ) {
629+ let part = strftime_ascii ( & ascii, & tm, vm) ?;
630+ out. extend ( part. chars ( ) ) ;
631+ ascii. clear ( ) ;
632+ }
633+ out. push ( codepoint) ;
634+ continue ;
635+ }
636+ if let Some ( ch) = codepoint. to_char ( )
637+ && ch. is_ascii ( )
638+ {
639+ ascii. push ( ch) ;
640+ continue ;
619641 }
620- size = size. saturating_mul ( 2 ) ;
642+
643+ if !ascii. is_empty ( ) {
644+ let part = strftime_ascii ( & ascii, & tm, vm) ?;
645+ out. extend ( part. chars ( ) ) ;
646+ ascii. clear ( ) ;
647+ }
648+ out. push ( codepoint) ;
649+ }
650+ if !ascii. is_empty ( ) {
651+ let part = strftime_ascii ( & ascii, & tm, vm) ?;
652+ out. extend ( part. chars ( ) ) ;
621653 }
654+ Ok ( out. to_pyobject ( vm) )
622655 }
623656
624657 #[ cfg( not( unix) ) ]
0 commit comments