Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tuna/rocmlir/config_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ConfigType(Enum):
convolution: str = 'convolution'
gemm: str = 'gemm'
attention: str = 'attention'
gemm_gemm: str = 'gemm_gemm'

def __str__(self) -> str:
return self.value
171 changes: 170 additions & 1 deletion tuna/rocmlir/rocmlir_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,148 @@ class AttentionResults(BASE, ResultsMixin): # pylint: disable=too-many-instance
index=True)


class GemmGemmJob(BASE, JobMixin):
"""Represents gemm_gemm job table"""
__tablename__ = "rocmlir_gemmgemm_job"
__table_args__ = (UniqueConstraint('config', 'session', name="uq_idx"),)

config = Column(Integer,
ForeignKey("rocmlir_gemmgemm_config.id"),
nullable=False,
index=True)


class GemmGemmConfig(BASE, SimpleCSVMixin):
"""Represents GemmGemm config table"""
__tablename__ = "rocmlir_gemmgemm_config"

data_type = Column(String(length=60), nullable=False, server_default="")
out_data_type = Column(String(length=60), nullable=False, server_default="")
group_size = Column(Integer, nullable=False, server_default="0")
m = Column(Integer, nullable=False, server_default="0")
n = Column(Integer, nullable=False, server_default="0")
k = Column(Integer, nullable=False, server_default="0")
transpose_A = Column(Boolean, nullable=False, server_default="0")
transpose_B = Column(Boolean, nullable=False, server_default="0")
kernel_repeats = Column(Integer, nullable=False, server_default="0")

def __repr__(self) -> str:
return f"GemmGemmConfig {self.to_dict()}"

options = {
'data_type': '-t',
'out_data_type': '-out_datatype',
'transpose_A': '-transA',
'transpose_B': '-transB',
'group_size': '-g',
'm': '-m',
'n': '-n',
'k': '-k',
'kernel_repeats': None,
'id': None,
'valid': None
}

def config_string(self):
"""Return config as a flag/value string suitable for tuningRunner.py."""
string = ""
for field, flag in self.options.items():
value = getattr(self, field, None)
if value is not None and flag is not None:
string += f"{flag} {value} "
return string.strip()

def parse_line(self, line):
"""Parse a command-line-style gemm_gemm config into a GemmGemmConfig object."""

print(f"Parsing line {line}")

i = iter(line.split())
options = dict(zip(i, i))

fields = {
'-transA': 'transpose_A',
'-transB': 'transpose_B',
'-g': 'group_size',
'-m': 'm',
'-n': 'n',
'-k': 'k',
'-t': 'data_type',
'-out_datatype': 'out_data_type'
}

self.kernel_repeats = 1
for flag, value in options.items():
if value in ["true", "True"]:
value = 1
if value in ["false", "False"]:
value = 0
field = fields[flag]
if field:
setattr(self, field, value)

def get_configurations(self, filename):
"""Read gemm_gemm-configs from filename and expand into all combinations of
type and transpose.
"""

DATA_TYPES = ['f32', 'f16', 'i8']

configs = []
with open(filename, 'r', encoding='utf8') as config_file:
lines = config_file.readlines()

for datatype, transA, transB, line in \
itertools.product(DATA_TYPES, ['false', 'true'],
['false', 'true'], lines):
line = line.strip()

if len(line) == 0 or line[0] == '#':
continue

dataTypeString = ""
if "-t " not in line:
dataTypeString = f"-t {datatype} "

transAString = ""
if "-transA " not in line:
transAString = f"-transA {transA} "

transBString = ""
if "-transB " not in line:
transBString = f"-transB {transB} "

outDataTypeString = ""
if "-out_datatype" not in line:
outDataTypeString = f"-out_datatype {datatype} "

one_config = f"{dataTypeString}{outDataTypeString}\
{transAString}{transBString}{line}".strip()
if one_config not in configs:
configs.append(one_config)

if "-out_datatype" not in line and datatype == 'i8':
outDataTypeString = "-out_datatype i32 "
one_config = f"{dataTypeString}{outDataTypeString}\
{transAString}{transBString}{line}".strip()
if one_config not in configs:
configs.append(one_config)

return configs


class GemmGemmResults(BASE, ResultsMixin): # pylint: disable=too-many-instance-attributes
"""Collects the results of GemmGemm tuning."""

__tablename__ = "rocmlir_gemmgemm_results"
__table_args__ = (UniqueConstraint("config_str", "session", name="uq_idx"),)

config = Column(Integer,
ForeignKey("rocmlir_gemmgemm_config.id"),
nullable=False,
index=True)


#pylint: disable=too-few-public-methods
class RocMLIRDBTables(DBTablesInterface):
"""Represents db tables for rocMLIR lib"""
Expand All @@ -804,6 +946,11 @@ def __init__(self, *, config_type=None, session_id, **kwargs):

def set_tables(self, sess_class=None):
"""Set appropriate tables based on requirements"""
# #region agent log
import json
with open('/home/djoantic/.cursor/debug.log', 'a') as f:
f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:805", "message": "set_tables called", "data": {"config_type": str(self.config_type), "config_type_value": self.config_type.value if self.config_type else None}, "sessionId": "debug-session", "hypothesisId": "B"}) + '\n')
# #endregion
if self.config_type == ConfigType.convolution:
self.job_table = ConvolutionJob
self.config_table = ConvolutionConfig
Expand All @@ -816,13 +963,27 @@ def set_tables(self, sess_class=None):
self.job_table = AttentionJob
self.config_table = AttentionConfig
self.results = AttentionResults
elif self.config_type == ConfigType.gemm_gemm:
self.job_table = GemmGemmJob
self.config_table = GemmGemmConfig
self.results = GemmGemmResults
else:
# #region agent log
import json
with open('/home/djoantic/.cursor/debug.log', 'a') as f:
f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:820", "message": "ValueError about to be raised", "data": {"config_type": str(self.config_type), "config_type_value": self.config_type.value if self.config_type else None}, "sessionId": "debug-session", "hypothesisId": "B"}) + '\n')
# #endregion
raise ValueError(f"Config type {self.config_type} not yet supported.")


def get_tables() -> List[BASE]:
"""Returns a list of all RocMLIR lib DB tables"""
tables: List[BASE] = []
# #region agent log
import json
with open('/home/djoantic/.cursor/debug.log', 'a') as f:
f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:823", "message": "get_tables called", "data": {}, "sessionId": "debug-session", "hypothesisId": "C"}) + '\n')
# #endregion
with DbSession() as session:
engine = session.bind
connect = session.connection()
Expand All @@ -843,7 +1004,15 @@ def append_if_not_exists(table):
append_if_not_exists(AttentionConfig())
append_if_not_exists(AttentionJob())
append_if_not_exists(AttentionResults())

append_if_not_exists(GemmGemmConfig())
append_if_not_exists(GemmGemmJob())
append_if_not_exists(GemmGemmResults())

# #region agent log
import json
with open('/home/djoantic/.cursor/debug.log', 'a') as f:
f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:848", "message": "get_tables returning", "data": {"tables_count": len(tables), "table_names": [t.__tablename__ for t in tables]}, "sessionId": "debug-session", "hypothesisId": "C"}) + '\n')
# #endregion
return tables


Expand Down
2 changes: 2 additions & 0 deletions tuna/rocmlir/rocmlir_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def run_cmd(self):
special_args = "--operation gemm"
elif self.dbt.config_type == ConfigType.attention:
special_args = "--operation attention --verify-mode none"
elif self.dbt.config_type == ConfigType.gemm_gemm:
special_args = "--operation gemm_gemm"
else:
raise ValueError(f"Config type {self.dbt.config_type} not yet supported.")
if self.dbt.session.tuning_space:
Expand Down