@@ -62,6 +62,16 @@ constexpr std::array<LayerAttentionType, kNum> FixedLayerConfig(
6262 return config;
6363}
6464
65+ template <size_t kNum >
66+ constexpr std::array<size_t , kNum > FixedAttentionWindowSizes (
67+ size_t window_size) {
68+ std::array<size_t , kNum > window_size_configs = {};
69+ for (size_t & l : window_size_configs) {
70+ l = window_size;
71+ }
72+ return window_size_configs;
73+ }
74+
6575template <size_t kNumLayers >
6676constexpr size_t NumLayersOfTypeBefore (
6777 const std::array<LayerAttentionType, kNumLayers >& layers,
@@ -114,10 +124,16 @@ template <typename TWeight>
114124struct ConfigGemma27B : public ConfigCapNoSSM {
115125 using Weight = TWeight; // make accessible where we only have a TConfig
116126
117- static constexpr int kSeqLen = gcpp:: kSeqLen ;
127+ static constexpr int kSeqLen = 8192 ;
118128 static constexpr int kVocabSize = 256000 ;
119129 static constexpr std::array<LayerAttentionType, 46 > kLayerConfig =
120130 FixedLayerConfig<46 >(LayerAttentionType::kGemma );
131+ static constexpr std::array<size_t , 46 > kAttentionWindowSizes = {
132+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
133+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
134+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
135+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
136+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen };
121137 static constexpr int kLayers = kLayerConfig .size();
122138 static constexpr int kGemmaLayers = kLayers ;
123139 static constexpr int kModelDim = 4608 ;
@@ -134,10 +150,16 @@ template <typename TWeight>
134150struct ConfigGemma9B : public ConfigCapNoSSM {
135151 using Weight = TWeight; // make accessible where we only have a TConfig
136152
137- static constexpr int kSeqLen = gcpp:: kSeqLen ;
153+ static constexpr int kSeqLen = 8192 ;
138154 static constexpr int kVocabSize = 256000 ;
139155 static constexpr std::array<LayerAttentionType, 42 > kLayerConfig =
140156 FixedLayerConfig<42 >(LayerAttentionType::kGemma );
157+ static constexpr std::array<size_t , 42 > kAttentionWindowSizes = {
158+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
159+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
160+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
161+ 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen , 4096 , kSeqLen ,
162+ 4096 , kSeqLen };
141163 static constexpr int kLayers = kLayerConfig .size();
142164 static constexpr int kGemmaLayers = kLayers ;
143165 static constexpr int kModelDim = 3584 ;
@@ -158,6 +180,8 @@ struct ConfigGemma7B : public ConfigNoCapNoSSM {
158180 static constexpr int kVocabSize = 256000 ;
159181 static constexpr std::array<LayerAttentionType, 28 > kLayerConfig =
160182 FixedLayerConfig<28 >(LayerAttentionType::kGemma );
183+ static constexpr std::array<size_t , 28 > kAttentionWindowSizes =
184+ FixedAttentionWindowSizes<28 >(kSeqLen );
161185 static constexpr int kLayers = kLayerConfig .size();
162186 static constexpr int kGemmaLayers = kLayers ;
163187 static constexpr int kModelDim = 3072 ;
@@ -178,6 +202,8 @@ struct ConfigGemma2B : public ConfigNoCapNoSSM {
178202 static constexpr int kVocabSize = 256000 ;
179203 static constexpr std::array<LayerAttentionType, 18 > kLayerConfig =
180204 FixedLayerConfig<18 >(LayerAttentionType::kGemma );
205+ static constexpr std::array<size_t , 18 > kAttentionWindowSizes =
206+ FixedAttentionWindowSizes<18 >(kSeqLen );
181207 static constexpr int kLayers = kLayerConfig .size();
182208 static constexpr int kGemmaLayers = kLayers ;
183209 static constexpr int kModelDim = 2048 ;
@@ -198,6 +224,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
198224 static constexpr int kVocabSize = 64 ;
199225 static constexpr std::array<LayerAttentionType, 3 > kLayerConfig =
200226 FixedLayerConfig<3 >(LayerAttentionType::kGemma );
227+ static constexpr std::array<size_t , 3 > kAttentionWindowSizes =
228+ FixedAttentionWindowSizes<3 >(kSeqLen );
201229 static constexpr int kLayers = kLayerConfig .size();
202230 static constexpr int kGemmaLayers = kLayers ;
203231 static constexpr int kModelDim = 128 ;
@@ -250,6 +278,8 @@ struct ConfigGriffin2B {
250278 LayerAttentionType::kGriffinRecurrentBlock ,
251279 LayerAttentionType::kGriffinRecurrentBlock ,
252280 };
281+ static constexpr std::array<size_t , 26 > kAttentionWindowSizes =
282+ FixedAttentionWindowSizes<26 >(kSeqLen );
253283 static constexpr int kLayers = kLayerConfig .size();
254284 static constexpr int kGemmaLayers =
255285 NumLayersOfTypeBefore (kLayerConfig , LayerAttentionType::kGemma , kLayers );
0 commit comments