# Copyright (C) 2022-2024 Exaloop Inc.
# Adapted in part from Google's Python re2 wrapper
# https://github.com/google/re2/blob/abseil/python/re2.py

# Flag bits handed unchanged to seq_re_compile() by compile(). Each
# single-letter name is an alias with the same value as the long name
# defined right after it.
A = (1 << 0)
ASCII = (1 << 0)
DEBUG = (1 << 1)
I = (1 << 2)
IGNORECASE = (1 << 2)
L = (1 << 3)
LOCALE = (1 << 3)
M = (1 << 4)
MULTILINE = (1 << 4)
S = (1 << 5)
DOTALL = (1 << 5)
X = (1 << 6)
VERBOSE = (1 << 6)

# Values for the `anchor` argument of seq_re_match():
# Pattern.search() passes _ANCHOR_NONE, Pattern.match() passes
# _ANCHOR_START and Pattern.fullmatch() passes _ANCHOR_BOTH.
_ANCHOR_NONE = 0
_ANCHOR_START = 1
_ANCHOR_BOTH = 2


# Half-open [start, end) offsets of one group within the subject string.
# A span is falsy exactly when both fields are -1.
# NOTE(review): presumably (-1, -1) is what the runtime reports for a group
# that did not participate in the match -- confirm against the C side.
@tuple
class Span:
    start: int
    end: int

    def __bool__(self):
        return not (self.start == -1 and self.end == -1)


# ---------------------------------------------------------------------------
# Runtime entry points declared via @C. Their implementations are not in this
# file; the notes below only describe how this module uses them.
# ---------------------------------------------------------------------------

# Returns a pointer to an array of spans; index 0 is the whole match and
# indices 1..groups are read as the capture groups (see Match).
@C
@pure
def seq_re_match(re: cobj, anchor: int, string: str, pos: int, endpos: int) -> Ptr[Span]:
    pass


# Single-span variant of seq_re_match(); not called from this file.
@C
@pure
def seq_re_match_one(re: cobj, anchor: int, string: str, pos: int, endpos: int) -> Span:
    pass


# Number of capture groups in a compiled pattern (see Pattern.groups).
@C
@pure
def seq_re_pattern_groups(re: cobj) -> int:
    pass


# Maps a group name to its numeric index (used by Match._get_group_str).
# NOTE(review): the result for an unknown name is not visible here; callers
# rely on it falling outside 0..groups so that IndexError is raised.
@C
@pure
def seq_re_group_name_to_index(re: cobj, name: str) -> int:
    pass


# Maps a group index to its name; Pattern.groupindex treats an empty result
# as "this group is unnamed".
@C
@pure
def seq_re_group_index_to_name(re: cobj, index: int) -> str:
    pass


# Error message for a compiled pattern; compile() treats an empty string as
# "compiled successfully".
@C
@pure
def seq_re_pattern_error(re: cobj) -> str:
    pass


# Backend for escape().
@C
@pure
def seq_re_escape(pattern: str) -> str:
    pass


# Backend for purge().
@C
def seq_re_purge() -> None:
    pass


# Backend for compile(); returns an opaque handle stored in Pattern._re.
@C
@pure
def seq_re_compile(pattern: str, flags: int) -> cobj:
    pass


# Exception raised for an invalid pattern (by compile()) or a bad escape in a
# replacement template (by Match._unescape()).
# `pattern` holds the offending pattern text when one is supplied.
class error(Static[Exception]):
    pattern: str

    def __init__(self, message: str = "", pattern: str = ""):
        super().__init__("re.error", message)
        self.pattern = pattern

    # Alias for the exception message, under the name `msg`.
    @property
    def msg(self):
        return self.message


# Compiled regular expression: the source text, the flags it was compiled
# with, and the opaque runtime handle. Methods are added further down in the
# `@extend class Pattern` section, after Match has been defined.
@tuple
class Pattern:
    pattern: str
    flags: int
    _re: cobj


# Compile `pattern` with `flags` and return a Pattern.
# Raises `error` (carrying the pattern text) when the runtime reports a
# non-empty error message for the compiled object.
def compile(pattern: str, flags: int = 0):
    re = seq_re_compile(pattern, flags)
    err_msg = seq_re_pattern_error(re)
    if err_msg:
        raise error(err_msg, pattern)
    return Pattern(pattern, flags, re)


# ---------------------------------------------------------------------------
# Module-level convenience wrappers: each compiles `pattern` and forwards to
# the Pattern method of the same name.
# ---------------------------------------------------------------------------

def search(pattern: str, string: str, flags: int = 0):
    return compile(pattern, flags).search(string)


def match(pattern: str, string: str, flags: int = 0):
    return compile(pattern, flags).match(string)


def fullmatch(pattern: str, string: str, flags: int = 0):
    return compile(pattern, flags).fullmatch(string)


def finditer(pattern: str, string: str, flags: int = 0):
    return compile(pattern, flags).finditer(string)


def findall(pattern: str, string: str, flags: int = 0):
    return compile(pattern, flags).findall(string)


def split(pattern: str, string: str, maxsplit: int = 0, flags: int = 0):
    return compile(pattern, flags).split(string, maxsplit)


def sub(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    return compile(pattern, flags).sub(repl, string, count)


def subn(pattern: str, repl, string: str, count: int = 0, flags: int = 0):
    return compile(pattern, flags).subn(repl, string, count)


# Escape `pattern` via the runtime's seq_re_escape().
def escape(pattern: str):
    return seq_re_escape(pattern)


# Forward to the runtime's seq_re_purge().
def purge():
    seq_re_purge()


# Result of a successful match.
#   _spans  - span array from seq_re_match(); [0] is the whole match
#   pos     - clamped start offset the search was run with
#   endpos  - clamped end offset the search was run with
#   re      - the Pattern that produced this match
#   string  - the subject string
@tuple
class Match:
    _spans: Ptr[Span]
    pos: int
    endpos: int
    re: Pattern
    string: str

    # Span of group number `g`, where `n` is the highest valid index.
    # Raises IndexError when `g` is outside 0..n.
    def _get_group_int(self, g: int, n: int):
        if not (0 <= g <= n):
            raise IndexError("no such group")
        return self._spans[g]

    # Span of the group named `g`; the name is resolved to an index first and
    # then bounds-checked by _get_group_int().
    def _get_group_str(self, g: str, n: int):
        return self._get_group_int(seq_re_group_name_to_index(self.re._re, g), n)

    # Span of group `g`, which may be an int, a str name, or any object
    # providing __index__().
    def _get_group(self, g, n: int):
        if isinstance(g, int):
            return self._get_group_int(g, n)
        elif isinstance(g, str):
            return self._get_group_str(g, n)
        else:
            return self._get_group(g.__index__(), n)

    # Text covered by `span`, or None when the span is the falsy (-1, -1).
    def _span_match(self, span: Span):
        if not span:
            return None
        return self.string._slice(span.start, span.end)

    # Text of group `g` (None when its span is falsy).
    def _get_match(self, g, n: int):
        span = self._get_group(g, n)
        return self._span_match(span)

    # Build a tuple with one entry per requested group by recursing over the
    # statically sized argument tuple.
    def _group_multi(self, n: int, *args):
        if staticlen(args) == 1:
            return (self._get_match(args[0], n),)
        else:
            return (self._get_match(args[0], n), *self._group_multi(n, *args[1:]))

    # group()        -> whole match as a plain str
    # group(g)       -> text of group g, or None
    # group(g1, ...) -> tuple of the above
    def group(self, *args):
        if staticlen(args) == 0:
            # Group 0 is always set on a Match (one is only constructed when
            # spans[0] is truthy), so the optional can be unwrapped directly.
            return self._get_match(0, 1).__val__()
        elif staticlen(args) == 1:
            return self._get_match(args[0], self.re.groups)
        else:
            return self._group_multi(self.re.groups, *args)

    # m[g] is the same lookup as m.group(g).
    def __getitem__(self, g):
        return self._get_match(g, self.re.groups)

    # Start offset of `group` (default: the whole match).
    def start(self, group = 0):
        return self._get_group(group, self.re.groups).start

    # End offset of `group` (default: the whole match).
    def end(self, group = 0):
        return self._get_group(group, self.re.groups).end

    # (start, end) offsets of `group` (default: the whole match).
    def span(self, group = 0):
        start, end = self._get_group(group, self.re.groups)
        return start, end

    # Split a replacement template into alternating pieces:
    # even indices hold literal text (still escaped), odd indices hold group
    # references (`\N`, `\NN` or `\g<...>`). The list always starts and ends
    # with a literal piece, possibly empty.
    def _split(template: str):
        backslash = '\\'
        pieces = ['']
        index = template.find(backslash)
        OCTAL = compile(r'\\[0-7][0-7][0-7]')
        GROUP = compile(r'\\[1-9][0-9]?|\\g<\w+>')
        while index != -1:
            # Text before the backslash belongs to the current literal piece.
            piece, template = template[:index], template[index:]
            pieces[-1] += piece
            octal_match = OCTAL.match(template)
            group_match = GROUP.match(template)
            # A three-digit octal escape takes priority over a group
            # reference that would otherwise match the same prefix.
            if (not octal_match) and group_match:
                index = group_match.end()
                piece, template = template[:index], template[index:]
                pieces.extend((piece, ''))
            else:
                # Not a group reference: keep the backslash and the character
                # after it in the literal piece; _unescape() decodes it later.
                index = 2
                piece, template = template[:index], template[index:]
                pieces[-1] += piece
            index = template.find(backslash)
        pieces[-1] += template
        return pieces

    # Decode backslash escapes in a literal template piece.
    # Handles \a \b \f \n \r \t \v \" \' \\ and octal escapes of one to three
    # digits. A backslash followed by any other alphabetic character raises
    # `error`; a backslash followed by anything else, or a trailing backslash,
    # is kept verbatim.
    def _unescape(s: str):
        r = []
        n = len(s)
        i = 0
        while i < n:
            if s[i] == '\\' and i + 1 < n:
                c = s[i + 1]
                # Each recognized branch advances `i` past the backslash; the
                # shared `i += 1` at the bottom then skips the escape character.
                if c == 'a':
                    r.append('\a')
                    i += 1
                elif c == 'b':
                    r.append('\b')
                    i += 1
                elif c == 'f':
                    r.append('\f')
                    i += 1
                elif c == 'n':
                    r.append('\n')
                    i += 1
                elif c == 'r':
                    r.append('\r')
                    i += 1
                elif c == 't':
                    r.append('\t')
                    i += 1
                elif c == 'v':
                    r.append('\v')
                    i += 1
                elif c == '"':
                    r.append('\"')
                    i += 1
                elif c == '\'':
                    r.append('\'')
                    i += 1
                elif c == '\\':
                    r.append('\\')
                    i += 1
                elif '0' <= c <= '7':
                    # An octal escape is at most three digits: s[i+1..i+3].
                    # `k` ends one past the last digit, so it must stop at
                    # i + 4; a fourth digit is ordinary text (this also agrees
                    # with the three-digit OCTAL pattern used by _split()).
                    k = i + 2
                    while k < n and k - i < 4 and '0' <= s[k] <= '7':
                        k += 1
                    code = int(s[i+1:k], 8)
                    # NOTE(review): values above 0o377 are not range-checked
                    # before the conversion to a single byte -- confirm the
                    # intended behaviour.
                    p = Ptr[byte](1)
                    p[0] = byte(code)
                    r.append(str(p, 1))
                    i = k - 1
                elif c.isalpha():
                    raise error(f"bad escape \\{c} at position {i}")
                else:
                    # Unknown non-alphabetic escape: emit the backslash now;
                    # the following character is emitted on the next pass.
                    r.append(s[i])
            else:
                r.append(s[i])
            i += 1
        return str.cat(r)

    # Expand a replacement template against this match: literal pieces are
    # unescaped, group references are replaced with the group's text, and a
    # group whose text is None contributes an empty string.
    def expand(self, template: str):
        def get_or_empty(s: Optional[str]):
            return s if s is not None else ''

        pieces = list(Match._split(template))
        INT = compile(r'[+-]?\d+')

        for index, piece in enumerate(pieces):
            if not (index % 2):
                # Even index: literal text.
                pieces[index] = Match._unescape(piece)
            else:
                if len(piece) <= 3:
                    # `\N` or `\NN`: drop the backslash and parse the number.
                    pieces[index] = get_or_empty(self[int(piece[1:])])
                else:
                    # `\g<...>`: strip the leading `\g<` and trailing `>`.
                    group = piece[3:-1]
                    if INT.fullmatch(group):
                        pieces[index] = get_or_empty(self[int(group)])
                    else:
                        pieces[index] = get_or_empty(self[group])
        return str.cat(pieces)

    # Index of the capture group with the greatest end offset, or None when
    # no group has an end offset above -1. The strict `<` keeps the
    # lowest-numbered group when several share the same end offset.
    @property
    def lastindex(self):
        max_end = -1
        max_group = None
        for group in range(1, self.re.groups + 1):
            end = self._spans[group].end
            if max_end < end:
                max_end = end
                max_group = group
        return max_group

    # Name of the group selected by `lastindex`, or None when there is none.
    @property
    def lastgroup(self):
        max_group = self.lastindex
        if max_group is None:
            return None
        return seq_re_group_index_to_name(self.re._re, max_group)

    # List of the text of groups 1..N, with `default` substituted for groups
    # whose text is None.
    def groups(self, default: Optional[str] = None):
        def get_or_default(item, default):
            return item if item is not None else default

        n = self.re.groups
        return [get_or_default(self._span_match(self._spans[i]), default)
                for i in range(1, n + 1)]

    # Dict mapping each named group to its text, with `default` substituted
    # for groups whose text is None.
    def groupdict(self, default: Optional[str] = None):
        d = {}
        for group, index in self.re.groupindex.items():
            item = self[index]
            d[group] = item if item is not None else default
        return d

    # Copying returns the same Match.
    def __copy__(self):
        return self

    def __deepcopy__(self):
        return self

    # A Match object is always truthy.
    def __bool__(self):
        return True


# Methods of Pattern. They are attached here, via @extend, because they
# construct Match objects and Match itself refers to Pattern.
@extend
class Pattern:
    # Number of capture groups in the pattern.
    @property
    def groups(self):
        return seq_re_pattern_groups(self._re)

    # Dict mapping group names to indices; groups whose name is empty are
    # omitted.
    @property
    def groupindex(self):
        d = {}
        for i in range(1, self.groups + 1):
            name = seq_re_group_index_to_name(self._re, i)
            if name:
                d[name] = i
        return d

    # Run a single match attempt with the given anchor mode.
    # `pos` / `endpos` default to the whole string and are clamped to
    # [0, len(string)]. Returns a Match, or None when the window is empty in
    # the wrong direction (pos > endpos) or nothing matched.
    def _match_one(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        posx = 0 if pos is None else max(0, min(pos.__val__(), len(string)))
        endposx = len(string) if endpos is None else max(0, min(endpos.__val__(), len(string)))
        if posx > endposx:
            return None
        spans = seq_re_match(self._re, anchor, string, posx, endposx)
        if not spans[0]:
            return None
        return Match(spans, posx, endposx, self, string)

    # Generator yielding successive matches within the clamped
    # [pos, endpos] window, using the same clamping as _match_one().
    def _match(self, anchor: int, string: str, pos: Optional[int], endpos: Optional[int]):
        posx = 0 if pos is None else max(0, min(pos.__val__(), len(string)))
        endposx = len(string) if endpos is None else max(0, min(endpos.__val__(), len(string)))
        if posx > endposx:
            return
        while True:
            spans = seq_re_match(self._re, anchor, string, posx, endposx)
            if not spans[0]:
                break
            yield Match(spans, posx, endposx, self, string)
            if posx == endposx:
                # The window is exhausted; stop after this match.
                break
            elif posx == spans[0][1]:
                # We matched the empty string at pos and would be stuck, so in
                # order to make forward progress, increment the bytes offset.
                posx += 1
            else:
                # Resume right after the end of the match just yielded.
                posx = spans[0][1]

    # First match anywhere in the window, or None.
    def search(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        return self._match_one(_ANCHOR_NONE, string, pos, endpos)

    # Match anchored at the start of the window, or None.
    def match(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        return self._match_one(_ANCHOR_START, string, pos, endpos)

    # Match anchored at both ends of the window, or None.
    def fullmatch(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        return self._match_one(_ANCHOR_BOTH, string, pos, endpos)

    # Generator over all unanchored matches in the window.
    def finditer(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        return self._match(_ANCHOR_NONE, string, pos, endpos)

    # List of the whole-match text of every match in the window. Capture
    # groups are not consulted: each element is always `m.group()`.
    def findall(self, string: str, pos: Optional[int] = None, endpos: Optional[int] = None):
        return [m.group() for m in self.finditer(string, pos, endpos)]

    # Shared driver for split() and subn().
    # Walks the matches in `string`; for each one it appends the text since
    # the previous match and then whatever list `cb(match)` returns. `T` is
    # the element type of the resulting list.
    #   maxsplit < 0  -> the string is returned unsplit, with a count of 0
    #   maxsplit == 0 -> no limit
    #   maxsplit > 0  -> at most that many matches are processed
    # Returns (pieces, number_of_matches_processed).
    def _split(self, cb, string: str, maxsplit: int = 0, T: type = str):
        if maxsplit < 0:
            return [T(string)], 0
        pieces: List[T] = []
        end = 0
        numsplit = 0
        for match in self.finditer(string):
            if (maxsplit > 0 and numsplit >= maxsplit):
                break
            pieces.append(string[end:match.start()])
            pieces.extend(cb(match))
            end = match.end()
            numsplit += 1
        # Whatever follows the last processed match.
        pieces.append(string[end:])
        return pieces, numsplit

    # Split `string` around matches. The text of every capture group of each
    # match is inserted between the pieces, so elements are Optional[str]
    # (a group's text may be None).
    def split(self, string: str, maxsplit: int = 0):
        cb = lambda match: [match[group] for group in range(1, self.groups + 1)]
        pieces, _ = self._split(cb, string, maxsplit, Optional[str])
        return pieces

    # Replacement text for one match: a str `repl` is treated as a template
    # and expanded; anything else is called with the match.
    def _repl(match, repl):
        if isinstance(repl, str):
            return match.expand(repl)
        else:
            return repl(match)

    # Replace matches in `string` (limited by `count` exactly as `maxsplit`
    # limits _split()) and return (new_string, number_of_replacements).
    def subn(self, repl, string: str, count: int = 0):
        cb = lambda match: [Pattern._repl(match, repl)]
        pieces, numsplit = self._split(cb, string, count, str)
        joined_pieces = str.cat(pieces)
        return joined_pieces, numsplit

    # Same as subn(), returning only the new string.
    def sub(self, repl, string: str, count: int = 0):
        joined_pieces, _ = self.subn(repl, string, count)
        return joined_pieces

    # A Pattern object is always truthy.
    def __bool__(self):
        return True