# Copyright (C) 2022-2023 Exaloop Inc.

from internal.file import _gz_errcheck
from internal.gc import sizeof, atomic


def pickle(x: T, jar: Jar, T: type):
    """Serialize `x` into `jar` by delegating to `x.__pickle__`."""
    x.__pickle__(jar)


def unpickle(jar: Jar, T: type) -> T:
    """Read a value of type `T` from `jar` via `T.__unpickle__`."""
    return T.__unpickle__(jar)


def dump(x: T, f, T: type):
    """Serialize `x` into the handle stored in `f.fp`."""
    x.__pickle__(f.fp)


def load(f, T: type) -> T:
    """Read a value of type `T` from the handle stored in `f.fp`."""
    return T.__unpickle__(f.fp)


def _write_raw(jar: Jar, p: cobj, n: int):
    """
    Write the `n` bytes starting at `p` to `jar` with `gzwrite`.

    The data is written in pieces of at most 0x7FFFFFFF bytes. When a call
    reports a count other than the requested piece size, `_gz_errcheck` is
    invoked and then `IOError` is raised.
    """
    max_piece = 0x7FFFFFFF
    cursor = p
    remaining = n
    while remaining > 0:
        piece = max_piece
        if remaining < max_piece:
            piece = remaining
        status = int(_C.gzwrite(jar, cursor, u32(piece)))
        if status != piece:
            _gz_errcheck(jar)
            raise IOError(f"pickle error: gzwrite returned {status}")
        cursor += piece
        remaining -= piece


def _read_raw(jar: Jar, p: cobj, n: int):
    """
    Fill the `n` bytes starting at `p` from `jar` with `gzread`.

    The data is read in pieces of at most 0x7FFFFFFF bytes. When a call
    reports a count other than the requested piece size, `_gz_errcheck` is
    invoked and then `IOError` is raised.
    """
    max_piece = 0x7FFFFFFF
    cursor = p
    remaining = n
    while remaining > 0:
        piece = max_piece
        if remaining < max_piece:
            piece = remaining
        status = int(_C.gzread(jar, cursor, u32(piece)))
        if status != piece:
            _gz_errcheck(jar)
            raise IOError(f"pickle error: gzread returned {status}")
        cursor += piece
        remaining -= piece


def _write(jar: Jar, x: T, T: type):
    """Write the `sizeof(T)` bytes of `x` to `jar` exactly as stored."""
    src = __ptr__(x)
    _write_raw(jar, src.as_byte(), sizeof(T))


def _read(jar: Jar, T: type) -> T:
    """Build a `T()` and overwrite its `sizeof(T)` bytes from `jar`."""
    value = T()
    dst = __ptr__(value)
    _read_raw(jar, dst.as_byte(), sizeof(T))
    return value


# Extend core types to allow pickling

# The scalar types below all share one scheme: the value's own bytes are
# written with `_write` and read back with `_read`.

@extend
class int:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> int:
        """Read an `int` from its raw bytes in `jar`."""
        return _read(jar, int)


@extend
class Int:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> Int[N]:
        """Read an `Int[N]` from its raw bytes in `jar`."""
        return _read(jar, Int[N])


@extend
class UInt:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> UInt[N]:
        """Read a `UInt[N]` from its raw bytes in `jar`."""
        return _read(jar, UInt[N])


@extend
class float:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> float:
        """Read a `float` from its raw bytes in `jar`."""
        return _read(jar, float)


@extend
class float32:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> float32:
        """Read a `float32` from its raw bytes in `jar`."""
        return _read(jar, float32)


@extend
class bool:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> bool:
        """Read a `bool` from its raw bytes in `jar`."""
        return _read(jar, bool)


@extend
class byte:
    def __pickle__(self, jar: Jar):
        """Write this value's raw bytes to `jar`."""
        _write(jar, self)

    def __unpickle__(jar: Jar) -> byte:
        """Read a `byte` from its raw bytes in `jar`."""
        return _read(jar, byte)


# Strings are stored as their length followed by that many raw bytes.

@extend
class str:
    def __pickle__(self, jar: Jar):
        """Write `self.len`, then the `self.len` bytes at `self.ptr`."""
        _write(jar, self.len)
        _write_raw(jar, self.ptr, self.len)

    def __unpickle__(jar: Jar) -> str:
        """Read a length, then that many bytes into a new buffer."""
        length = _read(jar, int)
        buf = Ptr[byte](length)
        _read_raw(jar, buf, length)
        return str(buf, length)


# Sequences are stored as an element count followed by the elements. When
# `atomic(T)` holds the element buffer is transferred as one raw block of
# `count * sizeof(T)` bytes; otherwise every element is pickled on its own.

@extend
class List:
    def __pickle__(self, jar: Jar):
        """Write the length, then the elements (raw block or one by one)."""
        count = len(self)
        pickle(count, jar)
        if atomic(T):
            _write_raw(jar, self.arr.ptr.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                pickle(self.arr[idx], jar)

    def __unpickle__(jar: Jar) -> List[T]:
        """Read the length, then fill a new `Array[T]` of that size."""
        count = unpickle(jar, int)
        items = Array[T](count)
        if atomic(T):
            _read_raw(jar, items.ptr.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                items[idx] = unpickle(jar, T)
        return List[T](items, count)


@extend
class DynamicTuple:
    def __pickle__(self, jar: Jar):
        """Write the length, then the elements (raw block or one by one)."""
        count = len(self)
        pickle(count, jar)
        if atomic(T):
            _write_raw(jar, self._ptr.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                pickle(self._ptr[idx], jar)

    def __unpickle__(jar: Jar) -> DynamicTuple[T]:
        """Read the length, then fill a new `Ptr[T]` buffer of that size."""
        count = unpickle(jar, int)
        items = Ptr[T](count)
        if atomic(T):
            _read_raw(jar, items.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                items[idx] = unpickle(jar, T)
        return DynamicTuple[T](items, count)


# Hash containers have two layouts, both starting with `_n_buckets`
# followed by the element count:
#   * all key/value types atomic: `_n_occupied`, `_upper_bound`, then the
#     flag words and the key (and value) arrays as raw blocks;
#   * otherwise: each entry pickled individually and re-inserted on load.

@extend
class Dict:
    def __pickle__(self, jar: Jar):
        """Write the table to `jar` in the raw or the per-entry layout."""
        import internal.khash as khash

        # Both layouts begin with the bucket count.
        pickle(self._n_buckets, jar)
        if atomic(K) and atomic(V):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            flag_words = khash.__ac_fsize(self._n_buckets) if self._n_buckets > 0 else 0
            _write_raw(jar, self._flags.as_byte(), flag_words * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
            _write_raw(jar, self._vals.as_byte(), self._n_buckets * sizeof(V))
        else:
            entries = len(self)
            pickle(entries, jar)
            for key, val in self.items():
                pickle(key, jar)
                pickle(val, jar)

    def __unpickle__(jar: Jar) -> Dict[K, V]:
        """Rebuild a dictionary from the layout written by `__pickle__`."""
        import internal.khash as khash

        d = {}
        # Both layouts begin with the bucket count and the element count.
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if atomic(K) and atomic(V):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            flag_words = khash.__ac_fsize(n_buckets) if n_buckets > 0 else 0

            flags = Ptr[u32](flag_words)
            keys = Ptr[K](n_buckets)
            vals = Ptr[V](n_buckets)
            _read_raw(jar, flags.as_byte(), flag_words * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))
            _read_raw(jar, vals.as_byte(), n_buckets * sizeof(V))

            # Install the loaded table directly into the fresh dictionary.
            d._flags = flags
            d._keys = keys
            d._vals = vals
            d._n_buckets = n_buckets
            d._size = size
            d._n_occupied = n_occupied
            d._upper_bound = upper_bound
        else:
            d.resize(n_buckets)
            for _ in range(size):
                key = unpickle(jar, K)
                val = unpickle(jar, V)
                d[key] = val
        return d


@extend
class Set:
    def __pickle__(self, jar: Jar):
        """Write the table to `jar` in the raw or the per-entry layout."""
        import internal.khash as khash

        # Both layouts begin with the bucket count.
        pickle(self._n_buckets, jar)
        if atomic(K):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            flag_words = khash.__ac_fsize(self._n_buckets) if self._n_buckets > 0 else 0
            _write_raw(jar, self._flags.as_byte(), flag_words * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
        else:
            entries = len(self)
            pickle(entries, jar)
            for key in self:
                pickle(key, jar)

    def __unpickle__(jar: Jar) -> Set[K]:
        """Rebuild a set from the layout written by `__pickle__`."""
        import internal.khash as khash

        s = set[K]()
        # Both layouts begin with the bucket count and the element count.
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if atomic(K):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            flag_words = khash.__ac_fsize(n_buckets) if n_buckets > 0 else 0

            flags = Ptr[u32](flag_words)
            keys = Ptr[K](n_buckets)
            _read_raw(jar, flags.as_byte(), flag_words * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))

            # Install the loaded table directly into the fresh set.
            s._flags = flags
            s._keys = keys
            s._n_buckets = n_buckets
            s._size = size
            s._n_occupied = n_occupied
            s._upper_bound = upper_bound
        else:
            s.resize(n_buckets)
            for _ in range(size):
                key = unpickle(jar, K)
                s.add(key)
        return s