Skip to content

Commit a9b4a06

Browse files
authored
misc: Add Ruff formatting (abetlen#2148)
* Add Ruff formatting and safe lint baseline
* Update changelog for Ruff setup
1 parent 9f661ff commit a9b4a06

22 files changed

+607
-510
lines changed

.github/workflows/lint.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: Lint
2+
3+
on:
4+
pull_request:
5+
branches:
6+
- main
7+
push:
8+
branches:
9+
- main
10+
11+
jobs:
12+
ruff:
13+
runs-on: ubuntu-latest
14+
steps:
15+
- uses: actions/checkout@v4
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v5
19+
with:
20+
python-version: "3.12"
21+
22+
- name: Install Ruff
23+
run: python -m pip install "ruff>=0.15.7"
24+
25+
- name: Lint with Ruff
26+
run: python -m ruff check llama_cpp tests
27+
28+
- name: Check formatting with Ruff
29+
run: python -m ruff format --check llama_cpp tests

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ coverage.xml
6666
*.py,cover
6767
.hypothesis/
6868
.pytest_cache/
69+
.ruff_cache/
6970
cover/
7071

7172
# Translations

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
- chore(dev): Add Ruff-based formatting and a safe lint baseline, and run it in CI for pull requests and pushes to `main`
1011
- fix(ci): Run macOS CI on supported Apple Silicon and Intel runners by @abetlen in #2150
1112
- fix(ci): Use the `hf` CLI instead of the deprecated `huggingface-cli` name in GitHub Actions and docs by @abetlen in #2149
1213

Makefile

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,14 @@ deploy.gh-docs:
6767
test:
6868
python3 -m pytest --full-trace -v
6969

70+
lint:
71+
python3 -m ruff check llama_cpp tests
72+
python3 -m ruff format --check llama_cpp tests
73+
74+
format:
75+
python3 -m ruff check --fix llama_cpp tests
76+
python3 -m ruff format llama_cpp tests
77+
7078
docker:
7179
docker build -t llama-cpp-python:latest -f docker/simple/Dockerfile .
7280

@@ -93,5 +101,7 @@ clean:
93101
build.sdist \
94102
deploy.pypi \
95103
deploy.gh-docs \
104+
lint \
105+
format \
96106
docker \
97107
clean

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,9 @@ pip install --upgrade pip
752752
# Install with pip
753753
pip install -e .
754754

755+
# install development tooling (tests, docs, ruff)
756+
pip install -e '.[dev]'
757+
755758
# if you want to use the fastapi / openapi server
756759
pip install -e '.[server]'
757760

@@ -768,6 +771,17 @@ Now try running the tests
768771
pytest
769772
```
770773

774+
And check formatting / linting before opening a PR:
775+
776+
```bash
777+
python -m ruff check llama_cpp tests
778+
python -m ruff format --check llama_cpp tests
779+
780+
# or use the Makefile targets
781+
make lint
782+
make format
783+
```
784+
771785
There's a `Makefile` available with useful targets.
772786
A typical workflow would look like this:
773787

llama_cpp/_ggml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
33
This module provides a minimal interface for working with ggml tensors from llama-cpp-python
44
"""
5+
56
import os
67
import pathlib
78

89
import llama_cpp._ctypes_extensions as ctypes_ext
910

1011
libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
1112
libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
12-

llama_cpp/_internals.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,9 @@ def get_embeddings_seq(self, seq_id: int):
355355
# Sampling functions - deprecated, use LlamaSampler instead
356356

357357
def set_rng_seed(self, seed: int):
358-
raise NotImplementedError("set_rng_seed is deprecated, use LlamaSampler instead")
358+
raise NotImplementedError(
359+
"set_rng_seed is deprecated, use LlamaSampler instead"
360+
)
359361

360362
def sample_repetition_penalties(
361363
self,
@@ -366,30 +368,44 @@ def sample_repetition_penalties(
366368
penalty_freq: float,
367369
penalty_present: float,
368370
):
369-
raise NotImplementedError("sample_repetition_penalties is deprecated, use LlamaSampler instead")
371+
raise NotImplementedError(
372+
"sample_repetition_penalties is deprecated, use LlamaSampler instead"
373+
)
370374

371375
def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
372-
raise NotImplementedError("sample_softmax is deprecated, use LlamaSampler instead")
376+
raise NotImplementedError(
377+
"sample_softmax is deprecated, use LlamaSampler instead"
378+
)
373379

374380
def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
375-
raise NotImplementedError("sample_top_k is deprecated, use LlamaSampler instead")
381+
raise NotImplementedError(
382+
"sample_top_k is deprecated, use LlamaSampler instead"
383+
)
376384

377385
def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
378-
raise NotImplementedError("sample_top_p is deprecated, use LlamaSampler instead")
386+
raise NotImplementedError(
387+
"sample_top_p is deprecated, use LlamaSampler instead"
388+
)
379389

380390
def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int):
381-
raise NotImplementedError("sample_min_p is deprecated, use LlamaSampler instead")
391+
raise NotImplementedError(
392+
"sample_min_p is deprecated, use LlamaSampler instead"
393+
)
382394

383395
def sample_typical(
384396
self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int
385397
):
386-
raise NotImplementedError("sample_typical is deprecated, use LlamaSampler instead")
398+
raise NotImplementedError(
399+
"sample_typical is deprecated, use LlamaSampler instead"
400+
)
387401

388402
def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float):
389403
raise NotImplementedError("sample_temp is deprecated, use LlamaSampler instead")
390404

391405
def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar):
392-
raise NotImplementedError("sample_grammar is deprecated, use LlamaSampler instead")
406+
raise NotImplementedError(
407+
"sample_grammar is deprecated, use LlamaSampler instead"
408+
)
393409

394410
def sample_token_mirostat(
395411
self,
@@ -399,7 +415,9 @@ def sample_token_mirostat(
399415
m: int,
400416
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
401417
) -> int:
402-
raise NotImplementedError("sample_token_mirostat is deprecated, use LlamaSampler instead")
418+
raise NotImplementedError(
419+
"sample_token_mirostat is deprecated, use LlamaSampler instead"
420+
)
403421

404422
def sample_token_mirostat_v2(
405423
self,
@@ -408,17 +426,25 @@ def sample_token_mirostat_v2(
408426
eta: float,
409427
mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float],
410428
) -> int:
411-
raise NotImplementedError("sample_token_mirostat_v2 is deprecated, use LlamaSampler instead")
429+
raise NotImplementedError(
430+
"sample_token_mirostat_v2 is deprecated, use LlamaSampler instead"
431+
)
412432

413433
def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int:
414-
raise NotImplementedError("sample_token_greedy is deprecated, use LlamaSampler instead")
434+
raise NotImplementedError(
435+
"sample_token_greedy is deprecated, use LlamaSampler instead"
436+
)
415437

416438
def sample_token(self, candidates: "_LlamaTokenDataArray") -> int:
417-
raise NotImplementedError("sample_token is deprecated, use LlamaSampler instead")
439+
raise NotImplementedError(
440+
"sample_token is deprecated, use LlamaSampler instead"
441+
)
418442

419443
# Grammar
420444
def grammar_accept_token(self, grammar: LlamaGrammar, token: int):
421-
raise NotImplementedError("grammar_accept_token is deprecated, use LlamaSampler instead")
445+
raise NotImplementedError(
446+
"grammar_accept_token is deprecated, use LlamaSampler instead"
447+
)
422448

423449
def reset_timings(self):
424450
llama_cpp.llama_perf_context_reset(self.ctx)
@@ -602,16 +628,16 @@ def sample(
602628
logits_array: Optional[npt.NDArray[np.single]] = None,
603629
):
604630
# This method is deprecated in favor of using LlamaSampler directly
605-
raise NotImplementedError("LlamaSamplingContext.sample is deprecated, use LlamaSampler instead")
631+
raise NotImplementedError(
632+
"LlamaSamplingContext.sample is deprecated, use LlamaSampler instead"
633+
)
606634

607635
def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool):
608636
self.prev.append(id)
609637

610638

611639
class CustomSampler:
612-
def __init__(
613-
self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]
614-
):
640+
def __init__(self, apply_func: Callable[[llama_cpp.llama_token_data_array], None]):
615641
self.apply_func = apply_func
616642

617643
def apply_wrapper(
@@ -723,28 +749,28 @@ def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
723749
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
724750

725751
def add_grammar_lazy_patterns(
726-
self,
727-
model: LlamaModel,
752+
self,
753+
model: LlamaModel,
728754
grammar: LlamaGrammar,
729755
trigger_patterns: List[str],
730-
trigger_tokens: List[int]
756+
trigger_tokens: List[int],
731757
):
732758
# Convert patterns to C array
733759
pattern_ptrs = (ctypes.c_char_p * len(trigger_patterns))()
734760
for i, pattern in enumerate(trigger_patterns):
735761
pattern_ptrs[i] = pattern.encode("utf-8")
736-
762+
737763
# Convert tokens to C array
738764
token_array = (llama_cpp.llama_token * len(trigger_tokens))(*trigger_tokens)
739-
765+
740766
sampler = llama_cpp.llama_sampler_init_grammar_lazy_patterns(
741767
model.vocab,
742768
grammar._grammar.encode("utf-8"),
743769
grammar._root.encode("utf-8"),
744770
pattern_ptrs,
745771
len(trigger_patterns),
746772
token_array,
747-
len(trigger_tokens)
773+
len(trigger_tokens),
748774
)
749775
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
750776

@@ -771,13 +797,13 @@ def add_dry(
771797
dry_base: float,
772798
dry_allowed_length: int,
773799
dry_penalty_last_n: int,
774-
seq_breakers: List[str]
800+
seq_breakers: List[str],
775801
):
776802
# Convert seq_breakers to C array
777803
breaker_ptrs = (ctypes.c_char_p * len(seq_breakers))()
778804
for i, breaker in enumerate(seq_breakers):
779805
breaker_ptrs[i] = breaker.encode("utf-8")
780-
806+
781807
sampler = llama_cpp.llama_sampler_init_dry(
782808
model.vocab,
783809
n_ctx_train,
@@ -786,25 +812,19 @@ def add_dry(
786812
dry_allowed_length,
787813
dry_penalty_last_n,
788814
breaker_ptrs,
789-
len(seq_breakers)
815+
len(seq_breakers),
790816
)
791817
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
792818

793-
def add_logit_bias(
794-
self,
795-
n_vocab: int,
796-
logit_bias: Dict[int, float]
797-
):
819+
def add_logit_bias(self, n_vocab: int, logit_bias: Dict[int, float]):
798820
# Convert logit_bias dict to C array
799821
bias_array = (llama_cpp.llama_logit_bias * len(logit_bias))()
800822
for i, (token, bias) in enumerate(logit_bias.items()):
801823
bias_array[i].token = token
802824
bias_array[i].bias = bias
803-
825+
804826
sampler = llama_cpp.llama_sampler_init_logit_bias(
805-
n_vocab,
806-
len(logit_bias),
807-
bias_array
827+
n_vocab, len(logit_bias), bias_array
808828
)
809829
llama_cpp.llama_sampler_chain_add(self.sampler, sampler)
810830

@@ -838,15 +858,17 @@ def reset(self):
838858
def clone(self):
839859
# NOTE: Custom samplers cannot be cloned due to Python callback limitations
840860
if self.custom_samplers:
841-
raise NotImplementedError("Cannot clone LlamaSampler that contains custom samplers")
842-
861+
raise NotImplementedError(
862+
"Cannot clone LlamaSampler that contains custom samplers"
863+
)
864+
843865
cloned_sampler = llama_cpp.llama_sampler_clone(self.sampler)
844866
# Create a new wrapper around the cloned sampler
845867
new_sampler = LlamaSampler.__new__(LlamaSampler)
846868
new_sampler.sampler = cloned_sampler
847869
new_sampler.custom_samplers = []
848870
new_sampler._exit_stack = ExitStack()
849-
871+
850872
def free_sampler():
851873
if new_sampler.sampler is not None:
852874
llama_cpp.llama_sampler_free(new_sampler.sampler)

llama_cpp/_logger.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
_last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
2727

28+
2829
# typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
2930
@llama_cpp.llama_log_callback
3031
def llama_log_callback(
@@ -34,7 +35,9 @@ def llama_log_callback(
3435
):
3536
# TODO: Correctly implement continue previous log
3637
global _last_log_level
37-
log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
38+
log_level = (
39+
GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
40+
)
3841
if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
3942
print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
4043
_last_log_level = log_level

0 commit comments

Comments (0)