e2e : add an inner op of the kmeans algorithm #164

Merged (2 commits, Jul 17, 2023)
40 changes: 39 additions & 1 deletion Readme.md
@@ -3,7 +3,9 @@ To test our wrapper, we use two strategies:
- port parts of the numpy test suite
- run several small examples which use NumPy and check that the results are identical to original NumPy.

Tests are run in the eager mode only, by replacing `import numpy as np` with `import torch_np as np`.
Examples are run in both eager and JIT modes.
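
For illustration, a minimal sketch of the import swap (the snippet itself is hypothetical; it only assumes `torch_np` is importable from this repo):

```python
# eager mode: only the import line changes, the array code stays as-is
import torch_np as np  # instead of `import numpy as np`

a = np.arange(12).reshape(3, 4)
print(a.mean(axis=0))
```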


For numpy tests, see `torch_np/testing/numpy_tests` folder.

@@ -13,6 +15,42 @@ For numpy tests, see `torch_np/testing/numpy_tests` folder.
- Build a random maze and find a path in it
- Simulate a diffusion/advection process
- Construct and visualize the Mandelbrot fractal
- Inner operation of k-means clustering

# JIT compiled mode

The main observation is that `torch.dynamo` unrolls Python-level loops. For
iterative algorithms this leads to very long compile times. We therefore
often compile only the inner loop.
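
A minimal sketch of this pattern (toy update function, not one of the benchmarks): the per-iteration work is compiled once, while the iteration itself stays a plain Python loop.

```python
import torch

@torch.compile
def inner_step(state):
    # the per-iteration work: traced and compiled once
    return 0.5 * state + 1.0

state = torch.zeros(1024)
for _ in range(1000):
    # the outer, iterative loop stays in eager Python and is not unrolled
    state = inner_step(state)
```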

## Maze path-finding

The Bellman-Ford algorithm simply does not compile because it contains a
data-dependent loop `while point != start`.
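
A self-contained sketch of the offending pattern (toy data, hypothetical names): the number of iterations depends on the maze data itself, so the loop cannot be traced into a fixed graph.

```python
# backtrack from `finish` to `start` through a predecessor map
start, finish = (0, 0), (2, 2)
came_from = {(2, 2): (1, 2), (1, 2): (1, 1), (1, 1): (0, 1), (0, 1): (0, 0)}

point = finish
path = [point]
while point != start:          # data-dependent loop condition
    point = came_from[point]   # iteration count depends on the maze itself
    path.append(point)
print(path[::-1])              # path from start to finish
```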


## CFD diffusion/advection process

We compile the inner loop of the diffusion-advection simulation. While the code
compiles, the performance is on par with or slightly worse than original NumPy.
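
As a rough sketch of the structure (a toy 1D explicit diffusion step, not the benchmark code; it reuses the same compile-the-inner-step pattern and the `numpy_ndarray_as_tensor` flag from `e2e/kmeans/kmeans.py` below):

```python
import numpy as np
import torch
import torch._dynamo.config as cfg
cfg.numpy_ndarray_as_tensor = True

@torch.compile
def diffusion_step(u, kappa=0.1):
    # explicit finite-difference update of a 1D field
    return u + kappa * (np.roll(u, 1) - 2.0 * u + np.roll(u, -1))

u = np.zeros(1024)
u[512] = 1.0
for _ in range(200):   # time stepping stays in Python
    u = diffusion_step(u)
```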

## Mandelbrot fractal

Results strongly depend on the implementation: a straightforward NumPy implementation
uses a data-dependent loop, which does not compile.

The implementation based on the [Mojo benchmark](https://shashankprasanna.com/benchmarking-modular-mojo-and-pytorch-torch.compile-on-mandelbrot-function/index.html#benchmarking-pytorch-cpu-with-torchcompile) makes it possible to compile the inner loop. The performance
increase relative to NumPy is substantial and depends strongly on the data size and the machine:
x8 for smaller inputs and up to x50 for inputs larger than the machine's cache size.
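
A sketch of the idea behind that formulation (not the benchmarked code): every iteration is a whole-array update with a fixed trip count, so there is no data-dependent early exit to break the trace.

```python
import numpy as np

def mandelbrot_counts(cx, cy, max_iter=200):
    x = np.zeros_like(cx)
    y = np.zeros_like(cy)
    counts = np.zeros(cx.shape, dtype=np.int64)
    for _ in range(max_iter):                # fixed trip count, no early exit
        bounded = x * x + y * y <= 4.0
        counts += bounded                    # only still-bounded points keep counting
        x_new = x * x - y * y + cx
        y_new = 2.0 * x * y + cy
        x = np.where(bounded, x_new, x)      # freeze escaped points
        y = np.where(bounded, y_new, y)
    return counts

# small grid over the complex plane
cy, cx = np.mgrid[-1.2:1.2:256j, -2.0:0.6:256j]
counts = mandelbrot_counts(cx, cy)
```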


## K-means clustering

The inner loop of the k-means algorithm (the `get_labels` step in `e2e/kmeans/kmeans.py` below) compiles into a straightforward
C++ loop and offers up to x30 speedups versus NumPy.


# Eager mode

In short, the main changes to examples are:

58 changes: 58 additions & 0 deletions e2e/kmeans/kmeans.py
@@ -0,0 +1,58 @@
# k-means step, given data, `X` and centroids
# https://realpython.com/numpy-array-programming/#clustering-algorithms
import numpy as np
import torch
torch.set_default_device("cpu")
import torch._dynamo.config as cfg
cfg.numpy_ndarray_as_tensor = True


# np.linalg.norm replacement (2-norm only), https://github.com/pytorch/pytorch/issues/105269
def norm(a, axis):
    s = (a.conj() * a).real
Review thread on this line:

Collaborator: We don't support complex numbers in inductor, so this could very well be `a ** 2`

Collaborator (author): so pytorch/pytorch#105267 is a wontfix?

Collaborator: yup

Collaborator (author): #165 has an example of working around not having complex dtypes.
    return np.sqrt(s.sum(axis=axis))


# @torch.compile  # compiled explicitly below via `torch.compile(get_labels)`
def get_labels(X, centroids) -> np.ndarray:
    # distance from every point to every centroid; label = index of the nearest centroid
    return np.argmin(norm(X - centroids[:, None], axis=2),
                     axis=0)


def init(npts):
    np.random.seed(12345)
    X = np.repeat([[5, 5], [10, 10]], [npts, npts], axis=0)
    X = X + np.random.randn(*X.shape)  # 2 distinct "blobs"
    centroids = np.array([[5, 5], [10, 10]])
    return X, centroids


################ benchmark #####################
import time

# ### numpy ###
npts = int(2e7)
X, centroids = init(npts)

start_time = time.time()
labels = get_labels(X, centroids)
end_time = time.time()
numpy_time = end_time - start_time
print("\n\nnumpy: elapsed=", numpy_time)


# ### compile ###
get_labels_c = torch.compile(get_labels)

# ### warm up ###
for _ in range(5):
    get_labels_c(X, centroids)


# ### measure ###
start_time = time.time()
labels = get_labels_c(X, centroids)
end_time = time.time()
compiled_time = end_time - start_time
print("compiled: elapsed=", compiled_time, ' speedup = ', numpy_time / compiled_time)
