from internal.file import _gz_errcheck
from internal.gc import sizeof, atomic


def pickle[T](x: T, jar: Jar):
    """Serialize `x` into `jar` by delegating to `x.__pickle__`."""
    x.__pickle__(jar)


def unpickle[T](jar: Jar):
    """Deserialize and return a `T` from `jar` via `T.__unpickle__`."""
    return T.__unpickle__(jar)


def dump[T](x: T, f):
    """Serialize `x` into the handle `f.fp` of the file object `f`.

    NOTE(review): `f.fp` is passed where a `Jar` is expected, so `f` is
    presumably a gzip-backed file object -- confirm against callers.
    """
    x.__pickle__(f.fp)


def load[T](f) -> T:
    """Deserialize and return a `T` from the handle `f.fp` of file object `f`."""
    return T.__unpickle__(f.fp)


def _write_raw(jar: Jar, p: cobj, n: int):
    """Write `n` raw bytes starting at `p` to `jar` using `gzwrite`.

    Data is sent in pieces of at most 0x7fffffff bytes (INT32_MAX),
    presumably so every per-call byte count fits gzwrite's `int` result.

    Raises:
        IOError: when a `gzwrite` call reports a count other than the number
            of bytes requested. `_gz_errcheck(jar)` is called first and may
            raise its own error instead.
    """
    CHUNK = 0x7fffffff
    cursor = p
    remaining = n
    while remaining > 0:
        step = remaining
        if step > CHUNK:
            step = CHUNK
        written = int(_C.gzwrite(jar, cursor, u32(step)))
        if written != step:
            _gz_errcheck(jar)
            raise IOError("pickle error: gzwrite returned " + str(written))
        cursor += step
        remaining -= step


def _read_raw(jar: Jar, p: cobj, n: int):
    """Read exactly `n` raw bytes from `jar` into the buffer at `p` using `gzread`.

    Mirrors `_write_raw`: transfers happen in pieces of at most 0x7fffffff
    bytes.

    Raises:
        IOError: when a `gzread` call returns a count other than the number
            of bytes requested (e.g. a short read). `_gz_errcheck(jar)` is
            called first and may raise its own error instead.
    """
    CHUNK = 0x7fffffff
    cursor = p
    remaining = n
    while remaining > 0:
        step = remaining
        if step > CHUNK:
            step = CHUNK
        got = int(_C.gzread(jar, cursor, u32(step)))
        if got != step:
            _gz_errcheck(jar)
            raise IOError("pickle error: gzread returned " + str(got))
        cursor += step
        remaining -= step


def _write[T](jar: Jar, x: T):
    """Write the in-memory bytes of `x` (exactly `sizeof(T)` bytes) to `jar`."""
    addr = __ptr__(x)
    _write_raw(jar, addr.as_byte(), sizeof(T))


def _read[T](jar: Jar):
    """Fill a default-constructed `T()` with `sizeof(T)` bytes from `jar` and return it."""
    value = T()
    addr = __ptr__(value)
    _read_raw(jar, addr.as_byte(), sizeof(T))
    return value


# Extend core types to allow pickling.
# Fixed-size scalar types are stored as their raw in-memory bytes.

@extend
class int:
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    def __unpickle__(jar: Jar):
        return _read(jar, int)


@extend
class float:
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    def __unpickle__(jar: Jar):
        return _read(jar, float)


@extend
class bool:
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    def __unpickle__(jar: Jar):
        return _read(jar, bool)


@extend
class byte:
    def __pickle__(self, jar: Jar):
        _write(jar, self)

    def __unpickle__(jar: Jar):
        return _read(jar, byte)


# Strings: the length (raw int) followed by that many raw bytes.
@extend
class str:
    def __pickle__(self, jar: Jar):
        _write(jar, self.len)
        _write_raw(jar, self.ptr, self.len)

    def __unpickle__(jar: Jar):
        length = _read(jar, int)
        buf = Ptr[byte](length)
        _read_raw(jar, buf, length)
        return str(buf, length)


# Lists: the element count, then either one raw block of element bytes
# (when T is atomic) or each element pickled in index order.
@extend
class List:
    def __pickle__(self, jar: Jar):
        count = len(self)
        pickle(count, jar)
        if atomic(T):
            _write_raw(jar, self.arr.ptr.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                pickle(self.arr[idx], jar)

    def __unpickle__(jar: Jar):
        count = unpickle(jar, int)
        buf = Array[T](count)
        if atomic(T):
            _read_raw(jar, buf.ptr.as_byte(), count * sizeof(T))
        else:
            for idx in range(count):
                buf[idx] = unpickle(jar, T)
        return List[T](buf, count)


# Dicts: always start with the bucket count. When both K and V are atomic,
# the hash table is dumped verbatim (bookkeeping ints, then the flag words,
# key buffer and value buffer). Otherwise the entry count follows, then
# each key/value pair pickled alternately.
@extend
class Dict:
    def __pickle__(self, jar: Jar):
        import internal.khash as khash
        pickle(self._n_buckets, jar)
        if atomic(K) and atomic(V):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            # Number of u32 flag words; only queried when buckets exist.
            fsize = 0
            if self._n_buckets > 0:
                fsize = khash.__ac_fsize(self._n_buckets)
            _write_raw(jar, self._flags.as_byte(), fsize * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
            _write_raw(jar, self._vals.as_byte(), self._n_buckets * sizeof(V))
        else:
            pickle(len(self), jar)
            for key, val in self.items():
                pickle(key, jar)
                pickle(val, jar)

    def __unpickle__(jar: Jar):
        import internal.khash as khash
        out = Dict[K,V]()
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if atomic(K) and atomic(V):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            fsize = 0
            if n_buckets > 0:
                fsize = khash.__ac_fsize(n_buckets)
            flags = Ptr[u32](fsize)
            keys = Ptr[K](n_buckets)
            vals = Ptr[V](n_buckets)
            _read_raw(jar, flags.as_byte(), fsize * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))
            _read_raw(jar, vals.as_byte(), n_buckets * sizeof(V))
            # Install the freshly read table directly into the new dict.
            out._n_buckets = n_buckets
            out._size = size
            out._n_occupied = n_occupied
            out._upper_bound = upper_bound
            out._flags = flags
            out._keys = keys
            out._vals = vals
        else:
            # Pre-size to the original bucket count, then re-insert entries.
            out.resize(n_buckets)
            for _ in range(size):
                # Separate statements keep the key read before the value read.
                key = unpickle(jar, K)
                val = unpickle(jar, V)
                out[key] = val
        return out


# Sets: same layout as Dict without a value buffer / value fields.
@extend
class Set:
    def __pickle__(self, jar: Jar):
        import internal.khash as khash
        pickle(self._n_buckets, jar)
        if atomic(K):
            pickle(self._size, jar)
            pickle(self._n_occupied, jar)
            pickle(self._upper_bound, jar)
            fsize = 0
            if self._n_buckets > 0:
                fsize = khash.__ac_fsize(self._n_buckets)
            _write_raw(jar, self._flags.as_byte(), fsize * sizeof(u32))
            _write_raw(jar, self._keys.as_byte(), self._n_buckets * sizeof(K))
        else:
            pickle(len(self), jar)
            for key in self:
                pickle(key, jar)

    def __unpickle__(jar: Jar):
        import internal.khash as khash
        out = Set[K]()
        n_buckets = unpickle(jar, int)
        size = unpickle(jar, int)
        if atomic(K):
            n_occupied = unpickle(jar, int)
            upper_bound = unpickle(jar, int)
            fsize = 0
            if n_buckets > 0:
                fsize = khash.__ac_fsize(n_buckets)
            flags = Ptr[u32](fsize)
            keys = Ptr[K](n_buckets)
            _read_raw(jar, flags.as_byte(), fsize * sizeof(u32))
            _read_raw(jar, keys.as_byte(), n_buckets * sizeof(K))
            out._n_buckets = n_buckets
            out._size = size
            out._n_occupied = n_occupied
            out._upper_bound = upper_bound
            out._flags = flags
            out._keys = keys
        else:
            out.resize(n_buckets)
            for _ in range(size):
                key = unpickle(jar, K)
                out.add(key)
        return out