diff --git a/tuna/rocmlir/config_type.py b/tuna/rocmlir/config_type.py index f55fe4ac..e4c965d2 100644 --- a/tuna/rocmlir/config_type.py +++ b/tuna/rocmlir/config_type.py @@ -33,6 +33,7 @@ class ConfigType(Enum): convolution: str = 'convolution' gemm: str = 'gemm' attention: str = 'attention' + gemm_gemm: str = 'gemm_gemm' def __str__(self) -> str: return self.value diff --git a/tuna/rocmlir/rocmlir_tables.py b/tuna/rocmlir/rocmlir_tables.py index de29738f..5b427aef 100644 --- a/tuna/rocmlir/rocmlir_tables.py +++ b/tuna/rocmlir/rocmlir_tables.py @@ -783,6 +783,148 @@ class AttentionResults(BASE, ResultsMixin): # pylint: disable=too-many-instance index=True) +class GemmGemmJob(BASE, JobMixin): + """Represents gemm_gemm job table""" + __tablename__ = "rocmlir_gemmgemm_job" + __table_args__ = (UniqueConstraint('config', 'session', name="uq_idx"),) + + config = Column(Integer, + ForeignKey("rocmlir_gemmgemm_config.id"), + nullable=False, + index=True) + + +class GemmGemmConfig(BASE, SimpleCSVMixin): + """Represents GemmGemm config table""" + __tablename__ = "rocmlir_gemmgemm_config" + + data_type = Column(String(length=60), nullable=False, server_default="") + out_data_type = Column(String(length=60), nullable=False, server_default="") + group_size = Column(Integer, nullable=False, server_default="0") + m = Column(Integer, nullable=False, server_default="0") + n = Column(Integer, nullable=False, server_default="0") + k = Column(Integer, nullable=False, server_default="0") + transpose_A = Column(Boolean, nullable=False, server_default="0") + transpose_B = Column(Boolean, nullable=False, server_default="0") + kernel_repeats = Column(Integer, nullable=False, server_default="0") + + def __repr__(self) -> str: + return f"GemmGemmConfig {self.to_dict()}" + + options = { + 'data_type': '-t', + 'out_data_type': '-out_datatype', + 'transpose_A': '-transA', + 'transpose_B': '-transB', + 'group_size': '-g', + 'm': '-m', + 'n': '-n', + 'k': '-k', + 'kernel_repeats': None, + 'id': None, + 'valid': None + } + + def config_string(self): + """Return config as a flag/value string suitable for tuningRunner.py.""" + string = "" + for field, flag in self.options.items(): + value = getattr(self, field, None) + if value is not None and flag is not None: + string += f"{flag} {value} " + return string.strip() + + def parse_line(self, line): + """Parse a command-line-style gemm_gemm config into a GemmGemmConfig object.""" + + print(f"Parsing line {line}") + + i = iter(line.split()) + options = dict(zip(i, i)) + + fields = { + '-transA': 'transpose_A', + '-transB': 'transpose_B', + '-g': 'group_size', + '-m': 'm', + '-n': 'n', + '-k': 'k', + '-t': 'data_type', + '-out_datatype': 'out_data_type' + } + + self.kernel_repeats = 1 + for flag, value in options.items(): + if value in ["true", "True"]: + value = 1 + if value in ["false", "False"]: + value = 0 + field = fields[flag] + if field: + setattr(self, field, value) + + def get_configurations(self, filename): + """Read gemm_gemm-configs from filename and expand into all combinations of + type and transpose. + """ + + DATA_TYPES = ['f32', 'f16', 'i8'] + + configs = [] + with open(filename, 'r', encoding='utf8') as config_file: + lines = config_file.readlines() + + for datatype, transA, transB, line in \ + itertools.product(DATA_TYPES, ['false', 'true'], + ['false', 'true'], lines): + line = line.strip() + + if len(line) == 0 or line[0] == '#': + continue + + dataTypeString = "" + if "-t " not in line: + dataTypeString = f"-t {datatype} " + + transAString = "" + if "-transA " not in line: + transAString = f"-transA {transA} " + + transBString = "" + if "-transB " not in line: + transBString = f"-transB {transB} " + + outDataTypeString = "" + if "-out_datatype" not in line: + outDataTypeString = f"-out_datatype {datatype} " + + one_config = f"{dataTypeString}{outDataTypeString}\ + {transAString}{transBString}{line}".strip() + if one_config not in configs: + configs.append(one_config) + + if "-out_datatype" not in line and datatype == 'i8': + outDataTypeString = "-out_datatype i32 " + one_config = f"{dataTypeString}{outDataTypeString}\ + {transAString}{transBString}{line}".strip() + if one_config not in configs: + configs.append(one_config) + + return configs + + +class GemmGemmResults(BASE, ResultsMixin): # pylint: disable=too-many-instance-attributes + """Collects the results of GemmGemm tuning.""" + + __tablename__ = "rocmlir_gemmgemm_results" + __table_args__ = (UniqueConstraint("config_str", "session", name="uq_idx"),) + + config = Column(Integer, + ForeignKey("rocmlir_gemmgemm_config.id"), + nullable=False, + index=True) + + #pylint: disable=too-few-public-methods class RocMLIRDBTables(DBTablesInterface): """Represents db tables for rocMLIR lib""" @@ -804,6 +946,11 @@ def __init__(self, *, config_type=None, session_id, **kwargs): def set_tables(self, sess_class=None): """Set appropriate tables based on requirements""" + # #region agent log + import json + with open('/home/djoantic/.cursor/debug.log', 'a') as f: + f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:805", "message": "set_tables called", "data": {"config_type": str(self.config_type), "config_type_value": self.config_type.value if self.config_type else None}, "sessionId": "debug-session", "hypothesisId": "B"}) + '\n') + # #endregion if self.config_type == ConfigType.convolution: self.job_table = ConvolutionJob self.config_table = ConvolutionConfig @@ -816,13 +963,27 @@ def set_tables(self, sess_class=None): self.job_table = AttentionJob self.config_table = AttentionConfig self.results = AttentionResults + elif self.config_type == ConfigType.gemm_gemm: + self.job_table = GemmGemmJob + self.config_table = GemmGemmConfig + self.results = GemmGemmResults else: + # #region agent log + import json + with open('/home/djoantic/.cursor/debug.log', 'a') as f: + f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:820", "message": "ValueError about to be raised", "data": {"config_type": str(self.config_type), "config_type_value": self.config_type.value if self.config_type else None}, "sessionId": "debug-session", "hypothesisId": "B"}) + '\n') + # #endregion raise ValueError(f"Config type {self.config_type} not yet supported.") def get_tables() -> List[BASE]: """Returns a list of all RocMLIR lib DB tables""" tables: List[BASE] = [] + # #region agent log + import json + with open('/home/djoantic/.cursor/debug.log', 'a') as f: + f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:823", "message": "get_tables called", "data": {}, "sessionId": "debug-session", "hypothesisId": "C"}) + '\n') + # #endregion with DbSession() as session: engine = session.bind connect = session.connection() @@ -843,7 +1004,15 @@ def append_if_not_exists(table): append_if_not_exists(AttentionConfig()) append_if_not_exists(AttentionJob()) append_if_not_exists(AttentionResults()) - + append_if_not_exists(GemmGemmConfig()) + append_if_not_exists(GemmGemmJob()) + append_if_not_exists(GemmGemmResults()) + + # #region agent log + import json + with open('/home/djoantic/.cursor/debug.log', 'a') as f: + f.write(json.dumps({"timestamp": __import__('time').time()*1000, "location": "rocmlir_tables.py:848", "message": "get_tables returning", "data": {"tables_count": len(tables), "table_names": [t.__tablename__ for t in tables]}, "sessionId": "debug-session", "hypothesisId": "C"}) + '\n') + # #endregion return tables diff --git a/tuna/rocmlir/rocmlir_worker.py b/tuna/rocmlir/rocmlir_worker.py index 8c9ab951..30e31fe6 100644 --- a/tuna/rocmlir/rocmlir_worker.py +++ b/tuna/rocmlir/rocmlir_worker.py @@ -180,6 +180,8 @@ def run_cmd(self): special_args = "--operation gemm" elif self.dbt.config_type == ConfigType.attention: special_args = "--operation attention --verify-mode none" + elif self.dbt.config_type == ConfigType.gemm_gemm: + special_args = "--operation gemm_gemm" else: raise ValueError(f"Config type {self.dbt.config_type} not yet supported.") if self.dbt.session.tuning_space: