Edit on GitHub

sqlglot.optimizer.pushdown_projections

  1from collections import defaultdict
  2
  3from sqlglot import alias, exp
  4from sqlglot.optimizer.qualify_columns import Resolver
  5from sqlglot.optimizer.scope import Scope, traverse_scope
  6from sqlglot.schema import ensure_schema
  7from sqlglot.errors import OptimizeError
  8
  9# Sentinel value meaning that an outer query is selecting ALL columns
 10SELECT_ALL = object()
 11
 12
 13# Selection to use if selection list is empty
 14def default_selection(is_agg: bool) -> exp.Alias:
 15    return alias(exp.Max(this=exp.Literal.number(1)) if is_agg else "1", "_")
 16
 17
 18def pushdown_projections(expression, schema=None, remove_unused_selections=True):
 19    """
 20    Rewrite sqlglot AST to remove unused columns projections.
 21
 22    Example:
 23        >>> import sqlglot
 24        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
 25        >>> expression = sqlglot.parse_one(sql)
 26        >>> pushdown_projections(expression).sql()
 27        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
 28
 29    Args:
 30        expression (sqlglot.Expression): expression to optimize
 31        remove_unused_selections (bool): remove selects that are unused
 32    Returns:
 33        sqlglot.Expression: optimized expression
 34    """
 35    # Map of Scope to all columns being selected by outer queries.
 36    schema = ensure_schema(schema)
 37    source_column_alias_count = {}
 38    referenced_columns = defaultdict(set)
 39
 40    # We build the scope tree (which is traversed in DFS postorder), then iterate
 41    # over the result in reverse order. This should ensure that the set of selected
 42    # columns for a particular scope are completely build by the time we get to it.
 43    for scope in reversed(traverse_scope(expression)):
 44        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
 45        alias_count = source_column_alias_count.get(scope, 0)
 46
 47        # We can't remove columns SELECT DISTINCT nor UNION DISTINCT.
 48        if scope.expression.args.get("distinct"):
 49            parent_selections = {SELECT_ALL}
 50
 51        if isinstance(scope.expression, exp.SetOperation):
 52            left, right = scope.union_scopes
 53            if len(left.expression.selects) != len(right.expression.selects):
 54                scope_sql = scope.expression.sql()
 55                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")
 56
 57            referenced_columns[left] = parent_selections
 58
 59            if any(select.is_star for select in right.expression.selects):
 60                referenced_columns[right] = parent_selections
 61            elif not any(select.is_star for select in left.expression.selects):
 62                if scope.expression.args.get("by_name"):
 63                    referenced_columns[right] = referenced_columns[left]
 64                else:
 65                    referenced_columns[right] = [
 66                        right.expression.selects[i].alias_or_name
 67                        for i, select in enumerate(left.expression.selects)
 68                        if SELECT_ALL in parent_selections
 69                        or select.alias_or_name in parent_selections
 70                    ]
 71
 72        if isinstance(scope.expression, exp.Select):
 73            if remove_unused_selections:
 74                _remove_unused_selections(scope, parent_selections, schema, alias_count)
 75
 76            if scope.expression.is_star:
 77                continue
 78
 79            # Group columns by source name
 80            selects = defaultdict(set)
 81            for col in scope.columns:
 82                table_name = col.table
 83                col_name = col.name
 84                selects[table_name].add(col_name)
 85
 86            # Push the selected columns down to the next scope
 87            for name, (node, source) in scope.selected_sources.items():
 88                if isinstance(source, Scope):
 89                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
 90                    referenced_columns[source].update(columns)
 91
 92                column_aliases = node.alias_column_names
 93                if column_aliases:
 94                    source_column_alias_count[source] = len(column_aliases)
 95
 96    return expression
 97
 98
 99def _remove_unused_selections(scope, parent_selections, schema, alias_count):
100    order = scope.expression.args.get("order")
101
102    if order:
103        # Assume columns without a qualified table are references to output columns
104        order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
105    else:
106        order_refs = set()
107
108    new_selections = []
109    removed = False
110    star = False
111    is_agg = False
112
113    select_all = SELECT_ALL in parent_selections
114
115    for selection in scope.expression.selects:
116        name = selection.alias_or_name
117
118        if select_all or name in parent_selections or name in order_refs or alias_count > 0:
119            new_selections.append(selection)
120            alias_count -= 1
121        else:
122            if selection.is_star:
123                star = True
124            removed = True
125
126        if not is_agg and selection.find(exp.AggFunc):
127            is_agg = True
128
129    if star:
130        resolver = Resolver(scope, schema)
131        names = {s.alias_or_name for s in new_selections}
132
133        for name in sorted(parent_selections):
134            if name not in names:
135                new_selections.append(
136                    alias(exp.column(name, table=resolver.get_table(name)), name, copy=False)
137                )
138
139    # If there are no remaining selections, just select a single constant
140    if not new_selections:
141        new_selections.append(default_selection(is_agg))
142
143    scope.expression.select(*new_selections, append=False, copy=False)
144
145    if removed:
146        scope.clear_cache()
SELECT_ALL = <object object>
def default_selection(is_agg: bool) -> sqlglot.expressions.Alias:
def default_selection(is_agg: bool) -> exp.Alias:
    """
    Build the placeholder projection to use when a selection list is empty.

    Args:
        is_agg: when truthy, the placeholder is an `exp.Max` over the numeric
            literal 1; otherwise it is built from the string "1".
    Returns:
        exp.Alias: the placeholder aliased as "_"
    """
    if is_agg:
        placeholder = exp.Max(this=exp.Literal.number(1))
    else:
        placeholder = "1"

    return alias(placeholder, "_")
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
def pushdown_projections(expression, schema=None, remove_unused_selections=True):
    """
    Rewrite sqlglot AST to remove unused column projections.

    The expression is modified in place and the same object is returned.

    Example:
        >>> import sqlglot
        >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
        >>> expression = sqlglot.parse_one(sql)
        >>> pushdown_projections(expression).sql()
        'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'

    Args:
        expression (sqlglot.Expression): expression to optimize
        schema: schema information, normalized with `ensure_schema`; it is handed to
            the column resolver when a removed star projection has to be expanded
        remove_unused_selections (bool): remove selects that are unused
    Returns:
        sqlglot.Expression: optimized expression
    Raises:
        OptimizeError: if the two sides of a set operation select a different
            number of columns
    """
    schema = ensure_schema(schema)

    # Map of Scope to the number of column aliases declared on the source wrapping it.
    source_column_alias_count = {}

    # Map of Scope to all columns being selected by outer queries.
    referenced_columns = defaultdict(set)

    # We build the scope tree (which is traversed in DFS postorder), then iterate
    # over the result in reverse order. This should ensure that the set of selected
    # columns for a particular scope is completely built by the time we get to it.
    for scope in reversed(traverse_scope(expression)):
        # A scope nobody registered columns for is treated as fully selected.
        parent_selections = referenced_columns.get(scope, {SELECT_ALL})
        alias_count = source_column_alias_count.get(scope, 0)

        # We can't remove columns from SELECT DISTINCT nor from UNION DISTINCT.
        if scope.expression.args.get("distinct"):
            parent_selections = {SELECT_ALL}

        if isinstance(scope.expression, exp.SetOperation):
            left, right = scope.union_scopes
            left_selects = left.expression.selects
            right_selects = right.expression.selects

            if len(left_selects) != len(right_selects):
                scope_sql = scope.expression.sql()
                raise OptimizeError(f"Invalid set operation due to column mismatch: {scope_sql}.")

            referenced_columns[left] = parent_selections

            if any(select.is_star for select in right_selects):
                referenced_columns[right] = parent_selections
            elif not any(select.is_star for select in left_selects):
                if scope.expression.args.get("by_name"):
                    referenced_columns[right] = referenced_columns[left]
                else:
                    # Columns are matched by position: keep the right-hand projection
                    # whenever the left-hand one at the same index is needed. A set is
                    # used so that every value of `referenced_columns` has the same
                    # type and repeated names collapse into one entry.
                    select_all = SELECT_ALL in parent_selections
                    referenced_columns[right] = {
                        right_select.alias_or_name
                        for left_select, right_select in zip(left_selects, right_selects)
                        if select_all or left_select.alias_or_name in parent_selections
                    }

        if isinstance(scope.expression, exp.Select):
            if remove_unused_selections:
                _remove_unused_selections(scope, parent_selections, schema, alias_count)

            # A star projection needs every column of its sources, so nothing
            # narrower is registered for them.
            if scope.expression.is_star:
                continue

            # Group columns by source name
            selects = defaultdict(set)
            for col in scope.columns:
                table_name = col.table
                col_name = col.name
                selects[table_name].add(col_name)

            # Push the selected columns down to the next scope
            for name, (node, source) in scope.selected_sources.items():
                if isinstance(source, Scope):
                    columns = {SELECT_ALL} if scope.pivots else selects.get(name) or set()
                    referenced_columns[source].update(columns)

                # Remember how many column aliases the source declares, so that many
                # projections are kept when the source's own scope is processed.
                column_aliases = node.alias_column_names
                if column_aliases:
                    source_column_alias_count[source] = len(column_aliases)

    return expression

Rewrite sqlglot AST to remove unused column projections.

Example:
>>> import sqlglot
>>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
>>> expression = sqlglot.parse_one(sql)
>>> pushdown_projections(expression).sql()
'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
Arguments:
  • expression (sqlglot.Expression): expression to optimize
  • remove_unused_selections (bool): remove selects that are unused
Returns:

sqlglot.Expression: optimized expression