Skip to content

Commit

Permalink
_write_bytes_not_aligned(): use self._io.tell() in EOF check
Browse files Browse the repository at this point in the history
In ced78c4, the EOF check used in _write_bytes_not_aligned() was
changed to use `self.pos()` instead of `self._io.tell()`, but this
turned out to be a mistake. Although it doesn't matter in most cases
because it's usually used via the write_bytes() method which aligns the
stream to a byte boundary first (and `self.pos()` is equivalent to
`self._io.tell()` on a byte boundary), there are few places where it
does matter.

One place is the write_align_to_byte() method, where the
misinterpretation of `self.pos()` would turn into an observable bug if
it weren't for the change in 28de847 (which was done for a different
reason) - we're testing for this bug since
kaitai-io/kaitai_struct_tests@cc36a88.
Another place are the actual _write_bytes_not_aligned() calls in the
write_bits_int_{be,le}() methods (even though I think in the current
implementation the EOF check inside _write_bytes_not_aligned() will not
be triggered).

All in all, it becomes clear that the EOF check in
_write_bytes_not_aligned() must not access `bits_left` - it should work
exclusively with the real byte position of the underlying I/O stream.
  • Loading branch information
generalmimon committed Jul 28, 2023
1 parent 28de847 commit 704995a
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions kaitaistruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,18 +454,13 @@ def bytes_terminate(data, term, include_term):

# region Writing

def _ensure_bytes_left_to_write(self, n):
def _ensure_bytes_left_to_write(self, n, pos):
try:
full_size = self._size
except AttributeError:
raise ValueError("writing to non-seekable streams is not supported")

# Unlike self._io.tell(), pos() respects the `bits_left` field (so it
# will return the stream position as if it were already aligned on a
# byte boundary), which is important when called from write_bits_int_*()
# methods (it ensures that we report the same numbers of bytes here as
# read_bits_int_*() methods would).
num_bytes_left = full_size - self.pos()
num_bytes_left = full_size - pos
if n > num_bytes_left:
raise EOFError(
"requested to write %d bytes, but only %d bytes left in the stream" %
Expand Down Expand Up @@ -609,7 +604,12 @@ def write_bits_int_be(self, n, val):

bits_to_write = self.bits_left + n
bytes_needed = ((bits_to_write - 1) // 8) + 1 # `ceil(bits_to_write / 8)`
self._ensure_bytes_left_to_write(bytes_needed - (1 if self.bits_left > 0 else 0))

# Unlike self._io.tell(), pos() respects the `bits_left` field (it
# returns the stream position as if it were already aligned on a byte
# boundary), which ensures that we report the same numbers of bytes here
# as read_bits_int_*() methods would.
self._ensure_bytes_left_to_write(bytes_needed - (1 if self.bits_left > 0 else 0), self.pos())

bytes_to_write = bits_to_write // 8
self.bits_left = bits_to_write % 8
Expand All @@ -633,13 +633,18 @@ def write_bits_int_le(self, n, val):
self.bits_le = True
self.bits_write_mode = True

bits_needed = self.bits_left + n
bytes_needed = ((bits_needed - 1) // 8) + 1 # `ceil(bits_needed / 8)`
self._ensure_bytes_left_to_write(bytes_needed - (1 if self.bits_left > 0 else 0))
bits_to_write = self.bits_left + n
bytes_needed = ((bits_to_write - 1) // 8) + 1 # `ceil(bits_to_write / 8)`

# Unlike self._io.tell(), pos() respects the `bits_left` field (it
# returns the stream position as if it were already aligned on a byte
# boundary), which ensures that we report the same numbers of bytes here
# as read_bits_int_*() methods would.
self._ensure_bytes_left_to_write(bytes_needed - (1 if self.bits_left > 0 else 0), self.pos())

bytes_to_write = bits_needed // 8
bytes_to_write = bits_to_write // 8
old_bits_left = self.bits_left
self.bits_left = bits_needed % 8
self.bits_left = bits_to_write % 8

if bytes_to_write > 0:
buf = bytearray(bytes_to_write)
Expand Down Expand Up @@ -668,7 +673,7 @@ def write_bytes(self, buf):

def _write_bytes_not_aligned(self, buf):
n = len(buf)
self._ensure_bytes_left_to_write(n)
self._ensure_bytes_left_to_write(n, self._io.tell())
self._io.write(buf)

def write_bytes_limit(self, buf, size, term, pad_byte):
Expand Down

0 comments on commit 704995a

Please sign in to comment.