|
5 | 5 | t-SNE is a nonlinear dimensionality reduction algorithm for visualizing |
6 | 6 | high-dimensional data in a low-dimensional space (2D or 3D). |
7 | 7 |
|
8 | | -It computes pairwise similarities in both spaces and minimizes the |
| 8 | +It computes pairwise similarities in both spaces and minimizes the |
9 | 9 | Kullback-Leibler divergence using gradient descent. |
10 | 10 |
|
11 | 11 | References: |
@@ -149,35 +149,33 @@ def apply_tsne( |
149 | 149 | return y |
150 | 150 |
|
151 | 151 |
|
152 | | -def main() -> None: |
| 152 | +def main() -> ndarray: |
153 | 153 | """ |
154 | | - Run t-SNE on Iris dataset and display the first 5 embeddings. |
| 154 | + Run t-SNE on the Iris dataset and return the embeddings. |
| 155 | + |
| 156 | + Returns: |
| 157 | + ndarray: t-SNE embedding of the Iris dataset |
155 | 158 |
|
156 | 159 | Example: |
157 | | - >>> main() # doctest: +ELLIPSIS |
158 | | - t-SNE embedding (first 5 points): |
159 | | - [[-... |
| 160 | + >>> result = main() |
| 161 | + >>> result.shape |
| 162 | + (150, 2) |
| 163 | + >>> isinstance(result, np.ndarray) |
| 164 | + True |
160 | 165 | """ |
161 | 166 | data_x, _ = collect_dataset() |
162 | 167 | y_emb = apply_tsne(data_x, n_components=2, n_iter=300) |
163 | 168 |
|
164 | 169 | if not isinstance(y_emb, np.ndarray): |
165 | 170 | raise TypeError("t-SNE embedding must be an ndarray") |
166 | 171 |
|
167 | | - print("t-SNE embedding (first 5 points):") |
168 | | - print(y_emb[:5]) |
169 | | - |
170 | | - # Optional visualization (commented, Ruff/mypy compliant) |
171 | | - # import matplotlib.pyplot as plt |
172 | | - # plt.scatter( |
173 | | - # y_emb[:, 0], |
174 | | - # y_emb[:, 1], |
175 | | - # c=_labels, |
176 | | - # cmap="viridis" |
177 | | - # ) |
178 | | - # plt.show() |
| 172 | + return y_emb |
179 | 173 |
|
180 | 174 |
|
181 | 175 | if __name__ == "__main__": |
182 | 176 | doctest.testmod() |
183 | | - main() |
| 177 | + |
| 178 | + # Demonstration of the algorithm |
| 179 | + result = main() |
| 180 | + print("t-SNE embedding (first 5 points):") |
| 181 | + print(result[:5]) |
0 commit comments