diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd index 2da588effd25..aebf9f69dd86 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd +++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pxd @@ -18,12 +18,11 @@ cimport cython from apache_beam.metrics.execution cimport MetricsContainer -from apache_beam.runners.worker.statesampler_interface cimport StateSamplerInterface from cpython cimport pythread from libc.stdint cimport int32_t, int64_t -cdef class StateSampler(StateSamplerInterface): +cdef class StateSampler(object): """Tracks time spent in states during pipeline execution.""" cdef int _sampling_period_ms cdef int _sampling_period_ms_start diff --git a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx index 7345ec618121..45700a0b0f81 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx +++ b/sdks/python/apache_beam/runners/worker/statesampler_fast.pyx @@ -36,7 +36,6 @@ import threading from apache_beam.utils.counters import CounterName from apache_beam.metrics.execution cimport MetricsContainer -from apache_beam.runners.worker.statesampler_interface cimport StateSamplerInterface cimport cython from cpython cimport pythread @@ -68,7 +67,7 @@ cdef inline int64_t get_nsec_time() noexcept nogil: current_time.tv_nsec) -cdef class StateSampler(StateSamplerInterface): +cdef class StateSampler(object): """Tracks time spent in states during pipeline execution.""" def __init__(self, diff --git a/sdks/python/apache_beam/runners/worker/statesampler_stub.py b/sdks/python/apache_beam/runners/worker/statesampler_stub.py index 60cdd0155c0a..563eaed8be43 100644 --- a/sdks/python/apache_beam/runners/worker/statesampler_stub.py +++ b/sdks/python/apache_beam/runners/worker/statesampler_stub.py @@ -20,10 +20,13 @@ class StubStateSampler(StateSamplerInterface): def __init__(self): - self._update_metric_calls = [] + self._update_metric_calls = {} def update_metric(self, typed_metric_name, value): - self._update_metric_calls.append((typed_metric_name, value)) + if (typed_metric_name not in self._update_metric_calls): + self._update_metric_calls[typed_metric_name] = value + return + self._update_metric_calls[typed_metric_name] += value def get_recorded_calls(self): return self._update_metric_calls diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 0b0d2f823bc3..f75d70edee71 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -2589,7 +2589,7 @@ def __init__(self, pool_submitter_fn, process_fn, timeout): """ Args: pool_submitter_fn (Callable): The process or thread pool submit function. - process_fn (Callable): DoFn#process function to be executed in a + process_fn (Callable): DoFn#process function to be executed in a subprocess or thread. timeout (Optional[float]): The maximum time allowed for execution. """ @@ -2634,7 +2634,9 @@ def submit(self, *args, **kwargs): tracker = get_current_tracker() if tracker is not None: - for typed_metric_name, value in stub_state_sampler.get_recorded_calls(): + for typed_metric_name, value in ( + stub_state_sampler.get_recorded_calls().items() + ): tracker.update_metric(typed_metric_name, value) if results is None: return diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 4d54596b6fda..253735be2baa 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -2783,10 +2783,15 @@ def test_timeout(self): def test_increment_counter(self): class CounterDoFn(beam.DoFn): def __init__(self): - self.records_counter = Metrics.counter(self.__class__, 'recordsCounter') + self.records_counter1 = Metrics.counter( + self.__class__, 'recordsCounter1') + self.records_counter2 = Metrics.counter( + self.__class__, 'recordsCounter2') def process(self, element): - self.records_counter.inc() + self.records_counter1.inc() + self.records_counter2.inc() + self.records_counter2.inc() yield element with TestPipeline() as p: @@ -2795,12 +2800,18 @@ def process(self, element): .with_exception_handling( use_subprocess=self.use_subprocess, timeout=1)) results = p.result - metric_results = results.metrics().query( - MetricsFilter().with_name("recordsCounter")) - records_counter = metric_results['counters'][0] - self.assertEqual(records_counter.key.metric.name, 'recordsCounter') - self.assertEqual(records_counter.result, 3) + metric_results1 = results.metrics().query( + MetricsFilter().with_name("recordsCounter1")) + records_counter1 = metric_results1['counters'][0] + metric_results2 = results.metrics().query( + MetricsFilter().with_name("recordsCounter2")) + records_counter2 = metric_results2['counters'][0] + + self.assertEqual(records_counter1.key.metric.name, 'recordsCounter1') + self.assertEqual(records_counter1.result, 3) + self.assertEqual(records_counter2.key.metric.name, 'recordsCounter2') + self.assertEqual(records_counter2.result, 6) def test_lifecycle(self): die = type(self).die