Skip to content

Commit

Permalink
Merge pull request #569 from piotrbartman/dev
Browse files Browse the repository at this point in the history
moments prange + test_moments type fix
  • Loading branch information
slayoo authored Jun 15, 2021
2 parents 5016346 + ba04bed commit 28cb91c
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 11 deletions.
137 changes: 137 additions & 0 deletions PySDM/backends/numba/impl/_atomic_operations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
https://github.com/KatanaGraph/katana/blob/master/python/katana/numba_support/numpy_atomic.py
"""

from numba import types
from numba.core import cgutils
from numba.core.typing.arraydecl import get_array_index_type
from numba.extending import lower_builtin, type_callable
from numba.np.arrayobj import basic_indexing, make_array, normalize_indices

__all__ = ["atomic_add", "atomic_sub", "atomic_max", "atomic_min"]


def atomic_rmw(context, builder, op, arrayty, val, ptr):
assert arrayty.aligned # We probably have to have aligned arrays.
dataval = context.get_value_as_data(builder, arrayty.dtype, val)
return builder.atomic_rmw(op, ptr, dataval, "monotonic")


def declare_atomic_array_op(iop, uop, fop):
def decorator(func):
@type_callable(func)
def func_type(context):
def typer(ary, idx, val):
out = get_array_index_type(ary, idx)
if out is not None:
res = out.result
if context.can_convert(val, res):
return res
return None

return typer

_ = func_type

@lower_builtin(func, types.Buffer, types.Any, types.Any)
def func_impl(context, builder, sig, args):
"""
array[a] = scalar_or_array
array[a,..,b] = scalar_or_array
"""
aryty, idxty, valty = sig.args
ary, idx, val = args

if isinstance(idxty, types.BaseTuple):
index_types = idxty.types
indices = cgutils.unpack_tuple(builder, idx, count=len(idxty))
else:
index_types = (idxty,)
indices = (idx,)

ary = make_array(aryty)(context, builder, ary)

# First try basic indexing to see if a single array location is denoted.
index_types, indices = normalize_indices(context, builder, index_types, indices)
dataptr, shapes, _strides = basic_indexing(
context, builder, aryty, ary, index_types, indices, boundscheck=context.enable_boundscheck,
)
if shapes:
raise NotImplementedError("Complex shapes are not supported")

# Store source value the given location
val = context.cast(builder, val, valty, aryty.dtype)
op = None
if isinstance(aryty.dtype, types.Integer) and aryty.dtype.signed:
op = iop
elif isinstance(aryty.dtype, types.Integer) and not aryty.dtype.signed:
op = uop
elif isinstance(aryty.dtype, types.Float):
op = fop
if op is None:
raise TypeError("Atomic operation not supported on " + str(aryty))
return atomic_rmw(context, builder, op, aryty, val, dataptr)

_ = func_impl

return func

return decorator


@declare_atomic_array_op("add", "add", "fadd")
def atomic_add(ary, i, v):
"""
Atomically, perform `ary[i] += v` and return the previous value of `ary[i]`.
i must be a simple index for a single element of ary. Broadcasting and vector operations are not supported.
This should be used from numba compiled code.
"""
orig = ary[i]
ary[i] += v
return orig


@declare_atomic_array_op("sub", "sub", "fsub")
def atomic_sub(ary, i, v):
"""
Atomically, perform `ary[i] -= v` and return the previous value of `ary[i]`.
i must be a simple index for a single element of ary. Broadcasting and vector operations are not supported.
This should be used from numba compiled code.
"""
orig = ary[i]
ary[i] -= v
return orig


@declare_atomic_array_op("max", "umax", None)
def atomic_max(ary, i, v):
"""
Atomically, perform `ary[i] = max(ary[i], v)` and return the previous value of `ary[i]`.
This operation does not support floating-point values.
i must be a simple index for a single element of ary. Broadcasting and vector operations are not supported.
This should be used from numba compiled code.
"""
orig = ary[i]
ary[i] = max(ary[i], v)
return orig


@declare_atomic_array_op("min", "umin", None)
def atomic_min(ary, i, v):
"""
Atomically, perform `ary[i] = min(ary[i], v)` and return the previous value of `ary[i]`.
This operation does not support floating-point values.
i must be a simple index for a single element of ary. Broadcasting and vector operations are not supported.
This should be used from numba compiled code.
"""
orig = ary[i]
ary[i] = min(ary[i], v)
return orig
12 changes: 8 additions & 4 deletions PySDM/backends/numba/impl/moments_methods.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numba

from PySDM.backends.numba import conf
from PySDM.backends.numba.impl._atomic_operations import atomic_add


class MomentsMethods:
Expand All @@ -10,11 +12,13 @@ def moments_body(
ranks, min_x, max_x, x_attr, weighting_attribute, weighting_rank):
moment_0[:] = 0
moments[:, :] = 0
for i in idx[:length]:
for idx_i in numba.prange(length):
i = idx[idx_i]
if min_x < x_attr[i] < max_x:
moment_0[cell_id[i]] += n[i] * weighting_attribute[i]**weighting_rank
for k in range(ranks.shape[0]): # TODO #401 (AtomicAdd)
moments[k, cell_id[i]] += n[i] * weighting_attribute[i]**weighting_rank * attr_data[i] ** ranks[k]
atomic_add(moment_0, cell_id[i], n[i] * weighting_attribute[i] ** weighting_rank)
for k in range(ranks.shape[0]):
atomic_add(moments, (k, cell_id[i]),
n[i] * weighting_attribute[i] ** weighting_rank * attr_data[i] ** ranks[k])
for c_id in range(moment_0.shape[0]):
for k in range(ranks.shape[0]):
moments[k, c_id] = moments[k, c_id] / moment_0[c_id] if moment_0[c_id] != 0 else 0
Expand Down
12 changes: 9 additions & 3 deletions PySDM_tests/unit_tests/backends/test_algorithmic_methods.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
import os

import numpy as np
import pytest

from PySDM.backends.numba.impl._algorithmic_methods import pair_indices
from PySDM.storages.index import make_Index
from PySDM.storages.pair_indicator import make_PairIndicator
from PySDM.storages.indexed_storage import make_IndexedStorage
from PySDM.storages.pair_indicator import make_PairIndicator
# noinspection PyUnresolvedReferences
from PySDM_tests.backends_fixture import backend

Expand Down Expand Up @@ -69,6 +71,10 @@ def test_adaptive_sdm_gamma(backend, gamma, idx, n, cell_id, dt_left, dt, dt_max

# Assert
np.testing.assert_array_almost_equal(_dt_left.to_ndarray(), np.asarray(expected_dt_left))
expected_gamma = (dt - np.asarray(expected_dt_left)) / dt * np.asarray(gamma)
expected_gamma = np.empty_like(np.asarray(gamma))
for i in range(len(idx)):
if is_first_in_pair[i]:
expected_gamma[i // 2] = (dt - np.asarray(expected_dt_left[cell_id[i]])) / dt * np.asarray(gamma)[
i // 2]
np.testing.assert_array_almost_equal(_gamma.to_ndarray(), expected_gamma)
np.testing.assert_array_equal(_n_substep, np.asarray(expected_n_substep))
8 changes: 4 additions & 4 deletions PySDM_tests/unit_tests/particles/test_moments.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import numpy as np

# noinspection PyUnresolvedReferences
from PySDM_tests.backends_fixture import backend
from PySDM.initialisation.multiplicities import discretise_n
from PySDM.physics.spectra import Lognormal
from PySDM.initialisation.spectral_sampling import Linear
from PySDM.physics.spectra import Lognormal
# noinspection PyUnresolvedReferences
from PySDM_tests.backends_fixture import backend
from PySDM_tests.unit_tests.dummy_core import DummyCore


Expand All @@ -30,7 +30,7 @@ def test_moment_0d(backend):
true_mean, true_var = spectrum.stats(moments='mv')

# TODO #217 : add a moments_0 wrapper
moment_0 = particles.backend.Storage.empty((1,), dtype=int)
moment_0 = particles.backend.Storage.empty((1,), dtype=float)
moments = particles.backend.Storage.empty((1, 1), dtype=float)

# Act
Expand Down

0 comments on commit 28cb91c

Please sign in to comment.