Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Change Logs

* :pr:`422`: add remove_inputs to InputObserver
* :pr:`421`: fix a few patches for MoE
* :pr:`426`: remove MambaCache

0.9.2
+++++
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_export/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def forward(self, x, y):

@hide_stdout()
@ignore_warnings(FutureWarning)
@requires_transformers("4.50")
@requires_transformers("4.57")
def test_tiny_llm_to_onnx(self):
import onnxruntime

Expand Down
18 changes: 0 additions & 18 deletions _unittests/ut_helpers/test_cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
make_dynamic_cache,
make_encoder_decoder_cache,
make_hybrid_cache,
make_mamba_cache,
make_sliding_window_cache,
make_static_cache,
)
Expand Down Expand Up @@ -150,23 +149,6 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
self.string_type(c2, with_shape=True),
)

@requires_transformers("4.51") # the structure changes
def test_make_mamba_cache(self):
cache = make_mamba_cache(
[
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
]
)
text = self.string_type(cache, with_shape=True)
self.assertEqual(
"MambaCache(conv_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4], "
"ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])",
text,
)
self.assertEqual(0, max_diff(cache, cache)["abs"])

@unittest.skipIf(
not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5"
)
Expand Down
21 changes: 1 addition & 20 deletions _unittests/ut_helpers/test_torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import onnx
import torch
import transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers import max_diff, string_type
from onnx_diagnostic.helpers.torch_helper import (
dummy_llm,
Expand All @@ -22,7 +22,6 @@
from onnx_diagnostic.helpers.cache_helper import (
make_dynamic_cache,
make_encoder_decoder_cache,
make_mamba_cache,
make_sliding_window_cache,
CacheKeyValue,
)
Expand Down Expand Up @@ -313,24 +312,6 @@ def test_torch_deepcopy_cache_dce(self):
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cc), 1)

@requires_torch("4.50")
def test_torch_deepcopy_mamba_cache(self):
cache = make_mamba_cache(
[
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
(torch.rand((4, 4, 4)), torch.rand((4, 4, 4))),
]
)
at = torch_deepcopy(cache)
self.assertEqual(type(cache), type(at))
self.assertEqual(max_diff(cache, at)["abs"], 0)
hash1 = string_type(at, with_shape=True, with_min_max=True)
cache.conv_states[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cache), 1)

def test_torch_deepcopy_base_model_outputs(self):
bo = transformers.modeling_outputs.BaseModelOutput(
last_hidden_state=torch.rand((4, 4, 4))
Expand Down
18 changes: 1 addition & 17 deletions _unittests/ut_tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
Expand Down Expand Up @@ -257,22 +257,6 @@ def test_sentence_similary(self):
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)

@hide_stdout()
def test_falcon_mamba_dev(self):
mid = "tiiuae/falcon-mamba-tiny-dev"
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
self.assertEqual(data["task"], "text-generation")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)
model(**data["inputs2"])
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
if not has_transformers("5.3.99"):
raise unittest.SkipTest("The model has control flow.")
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
)


if __name__ == "__main__":
unittest.main(verbosity=2)
74 changes: 0 additions & 74 deletions _unittests/ut_tasks/try_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,80 +683,6 @@ def mean_pooling(model_output, attention_mask):
print("Sentence embeddings:")
print(sentence_embeddings)

@never_test()
def test_falcon_mamba_dev(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_dev
# https://huggingface.co/tiiuae/falcon-mamba-tiny-dev

from transformers import AutoTokenizer
import transformers
import torch

model = "tiiuae/falcon-mamba-tiny-dev"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
print()
with steal_forward(pipeline.model):
sequences = pipeline(
"Girafatron is obsessed with giraffes, "
"the most glorious animal on the face of this Earth. "
"Giraftron believes all other animals are irrelevant "
"when compared to the glorious majesty of the giraffe."
"\nDaniel: Hello, Girafatron!\nGirafatron:",
max_length=200,
do_sample=True,
top_k=10,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
print(f"Result: {seq['generated_text']}")

@never_test()
def test_falcon_mamba_7b(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k falcon_mamba_7b
# https://huggingface.co/tiiuae/falcon-mamba-7b

from transformers import AutoTokenizer
import transformers
import torch

model = "tiiuae/falcon-mamba-7b"

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
print()
with steal_forward(pipeline.model):
sequences = pipeline(
"Girafatron is obsessed with giraffes, "
"the most glorious animal on the face of this Earth. "
"Giraftron believes all other animals are irrelevant "
"when compared to the glorious majesty of the giraffe."
"\nDaniel: Hello, Girafatron!\nGirafatron:",
max_length=200,
do_sample=True,
top_k=10,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
)
for seq in sequences:
print(f"Result: {seq['generated_text']}")

@never_test()
def test_object_detection(self):
# clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k object_
Expand Down
122 changes: 1 addition & 121 deletions _unittests/ut_torch_export_patches/test_onnx_export_errors.py
Original file line number Diff line number Diff line change
@@ -1,128 +1,8 @@
import unittest
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
requires_torch,
requires_transformers,
skipif_ci_windows,
ignore_warnings,
hide_stdout,
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
torch_export_patches,
)
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings


class TestOnnxExportErrors(ExtTestCase):
@requires_transformers("4.49.999")
@skipif_ci_windows("not working on Windows")
@ignore_warnings(UserWarning)
@hide_stdout()
def test_pytree_flatten_mamba_cache(self):
import torch
import torch.utils._pytree as py_pytree

try:
from transformers.models.mamba.modeling_mamba import MambaCache
except ImportError:
from transformers.cache_utils import MambaCache

class _config:
def __init__(self):
self.intermediate_size = 8
self.state_size = 16
self.conv_kernel = 32
self.num_hidden_layers = 64
self.dtype = torch.float16

cache = MambaCache(_config(), max_batch_size=1, device="cpu")

with torch_export_patches(verbose=1):
values, spec = py_pytree.tree_flatten(cache)
cache2 = py_pytree.tree_unflatten(values, spec)
self.assertEqual(cache.max_batch_size, cache2.max_batch_size)
self.assertEqual(cache.intermediate_size, cache2.intermediate_size)
self.assertEqual(cache.ssm_state_size, cache2.ssm_state_size)
self.assertEqual(cache.conv_kernel_size, cache2.conv_kernel_size)
self.assertEqualArrayAny(cache.conv_states, cache2.conv_states)
self.assertEqualArrayAny(cache.ssm_states, cache2.ssm_states)

@requires_transformers("4.50")
@requires_torch("2.7")
@skipif_ci_windows("not working on Windows")
@ignore_warnings(UserWarning)
@hide_stdout()
def test_exportable_mamba_cache(self):
import torch
from transformers.models.mamba.modeling_mamba import MambaCache

class _config:
def __init__(self):
self.intermediate_size = 8
self.state_size = 16
self.conv_kernel = 32
self.num_hidden_layers = 64
self.dtype = torch.float16

class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: MambaCache):
x1 = cache.ssm_states[0] + x
x2 = cache.conv_states[0][:, :, ::2] + x1
return x2

cache = MambaCache(_config(), max_batch_size=1, device="cpu")
# MambaCache was updated in 4.50
self.assertEqual(
"MambaCache(conv_states=#64[T10r3,...], ssm_states=#64[T10r3,...])",
string_type(cache),
)
x = torch.ones(2, 8, 16).to(torch.float16)
model = Model()
model(x, cache)

with torch_export_patches(verbose=1, patch_transformers=True):
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
torch.export.export(Model(), (x, cache))

@requires_transformers("4.49.999")
@skipif_ci_windows("not working on Windows")
@ignore_warnings(UserWarning)
def test_exportable_mamba_cache_dynamic(self):
import torch
from transformers.models.mamba.modeling_mamba import MambaCache

class _config:
def __init__(self):
self.intermediate_size = 8
self.state_size = 16
self.conv_kernel = 32
self.num_hidden_layers = 2
self.dtype = torch.float16

class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, cache: MambaCache):
x1 = cache.ssm_states[0] + x
x2 = cache.conv_states[0][:, :, ::2] + x1
return x2

cache = MambaCache(_config(), max_batch_size=1, device="cpu")
self.assertEqual(
string_type(cache),
"MambaCache(conv_states=#2[T10r3,T10r3], ssm_states=#2[T10r3,T10r3])",
)
x = torch.ones(2, 8, 16).to(torch.float16)
model = Model()
model(x, cache)
DYN = torch.export.Dim.DYNAMIC

with torch_export_patches(patch_transformers=True):
cache = MambaCache(_config(), max_batch_size=2, device="cpu")
torch.export.export(
Model(),
(x, cache),
dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
)

@ignore_warnings(UserWarning)
def test_exportable_dynamic_shapes_constraints(self):
import torch
Expand Down
Loading
Loading