gfx/angle/checkout/src/compiler/translator/tree_ops/ArrayReturnValueToOutParameter.cpp
author Jeff Gilbert <jgilbert@mozilla.com>
Fri, 15 Mar 2019 22:55:50 -0700
changeset 470378 167ee7c46b84bc9f0988896d74adc810ec2e495a
parent 419438 b7c91a6f1b0a72da63f430d8f6c57d5374d7e0a7
permissions -rw-r--r--
Bug 1520948 - Update ANGLE to chromium/3729..moz/firefox-68. Differential Revision: https://phabricator.services.mozilla.com/D23772

//
// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// The ArrayReturnValueToOutParameter function changes return values of an array type to out
// parameters in function definitions, prototypes, and call sites.

#include "compiler/translator/tree_ops/ArrayReturnValueToOutParameter.h"

#include <map>

#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"

namespace sh
{

namespace
{

// Name given to the out parameter that replaces an array return value.
constexpr const ImmutableString kReturnValueVariableName("angle_return");

// AST traverser that rewrites functions returning arrays into void functions which write their
// result into an extra trailing out parameter. Prototypes, definitions, call sites and return
// statements are all rewritten. Run it through apply().
class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
{
  public:
    // Traverses |root|, queueing every rewrite, and then commits the queued edits to the tree.
    static void apply(TIntermNode *root, TSymbolTable *symbolTable);

  private:
    explicit ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);

    void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitBranch(Visit visit, TIntermBranch *node) override;
    bool visitBinary(Visit visit, TIntermBinary *node) override;

    // Builds a call to the rewritten version of |originalCall|'s function, passing the original
    // arguments followed by |returnValueTarget| as the out argument.
    TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
                                            TIntermTyped *returnValueTarget);

    // Set when traversal is inside a function with array return value.
    TIntermFunctionDefinition *mFunctionWithArrayReturnValue = nullptr;

    struct ChangedFunction
    {
        // Out parameter that receives what the function used to return.
        const TVariable *returnValueVariable = nullptr;
        // Void-returning replacement function carrying the extra out parameter.
        const TFunction *func = nullptr;
    };

    // Map from function symbol ids to the changed function.
    std::map<int, ChangedFunction> mChangedFunctions;
};

// Returns a new call to the void-returning replacement of |originalCall|'s function. The original
// arguments are kept in order and |returnValueTarget| is appended as the trailing out argument.
// The new call inherits the source location of |originalCall|.
TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
    TIntermAggregate *originalCall,
    TIntermTyped *returnValueTarget)
{
    ASSERT(originalCall->getFunction());
    const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();

    // Use find() rather than operator[]: a miss must be reported, not turned into a freshly
    // inserted entry whose null |func| would be dereferenced below.
    auto changed = mChangedFunctions.find(originalId.get());
    ASSERT(changed != mChangedFunctions.end());

    const TIntermSequence *originalArguments = originalCall->getSequence();
    TIntermSequence *replacementArguments    = new TIntermSequence();
    replacementArguments->reserve(originalArguments->size() + 1);
    replacementArguments->insert(replacementArguments->end(), originalArguments->begin(),
                                 originalArguments->end());
    replacementArguments->push_back(returnValueTarget);

    TIntermAggregate *replacementCall =
        TIntermAggregate::CreateFunctionCall(*changed->second.func, replacementArguments);
    replacementCall->setLine(originalCall->getLine());
    return replacementCall;
}

// Runs the pass over |root|: one traversal queues every rewrite, then updateTree() commits them.
void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable)
{
    ArrayReturnValueToOutParameterTraverser traverser(symbolTable);

    // Replacements are only queued during traversal ...
    root->traverse(&traverser);

    // ... and applied to the tree afterwards in one step.
    traverser.updateTree();
}

// Visits nodes both before (PreVisit) and after (PostVisit) their children, but not in between.
// The traversal starts outside any function.
ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
    TSymbolTable *symbolTable)
    : TIntermTraverser(/* preVisit */ true, /* inVisit */ false, /* postVisit */ true,
                       symbolTable),
      mFunctionWithArrayReturnValue(nullptr)
{}

// Tracks whether traversal is currently inside the body of an array-returning function, so that
// visitBranch() knows which return statements to rewrite. The function header itself is replaced
// in visitFunctionPrototype(). Always descends into the definition.
bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
    Visit visit,
    TIntermFunctionDefinition *node)
{
    if (visit == PreVisit)
    {
        if (node->getFunctionPrototype()->isArray())
        {
            mFunctionWithArrayReturnValue = node;
        }
    }
    else if (visit == PostVisit)
    {
        // Leaving the function: its return statements are no longer in scope.
        mFunctionWithArrayReturnValue = nullptr;
    }
    return true;
}

// Replaces the prototype of an array-returning function with one for a void function that has
// an additional trailing out parameter. The replacement function is created once per function id
// and cached in mChangedFunctions; later visits of the same function reuse it.
void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
    if (!node->isArray())
    {
        return;
    }

    const TFunction *original          = node->getFunction();
    const TSymbolUniqueId &originalId  = original->uniqueId();

    auto changed = mChangedFunctions.find(originalId.get());
    if (changed == mChangedFunctions.end())
    {
        // The former return type becomes the type of an out parameter.
        TType *outParamType = new TType(node->getType());
        outParamType->setQualifier(EvqOut);

        ChangedFunction entry;
        entry.returnValueVariable = new TVariable(mSymbolTable, kReturnValueVariableName,
                                                  outParamType, SymbolType::AngleInternal);

        // Same name and parameters as the original, but returning void and taking the
        // out parameter last.
        TFunction *voidFunction =
            new TFunction(mSymbolTable, original->name(), original->symbolType(),
                          StaticType::GetBasic<EbtVoid>(), false);
        const size_t paramCount = original->getParamCount();
        for (size_t paramIndex = 0; paramIndex < paramCount; ++paramIndex)
        {
            voidFunction->addParameter(original->getParam(paramIndex));
        }
        voidFunction->addParameter(entry.returnValueVariable);
        entry.func = voidFunction;

        changed = mChangedFunctions.emplace(originalId.get(), entry).first;
    }

    TIntermFunctionPrototype *replacement = new TIntermFunctionPrototype(changed->second.func);
    replacement->setLine(node->getLine());
    queueReplacement(replacement, OriginalNode::IS_DROPPED);
}

// Rewrites calls to array-returning functions whose result is discarded. Any other aggregate is
// descended into as usual.
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    // Raw function calls are not expected to produce arrays at this point.
    ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);

    const bool isArrayReturningCall =
        visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST;
    if (!isArrayReturningCall)
    {
        return true;
    }

    // For a function f() returning an array, an unassigned call can look like:
    //   1. f();
    //   2. another_array == f();
    //   3. another_function(f());
    //   4. return f();
    // SeparateExpressionsReturningArrays has already turned forms 2-4 into simpler code, so only
    // a call that stands alone as a statement (its parent is a block) is handled here.
    TIntermBlock *enclosingBlock = getParentNode()->getAsBlock();
    if (enclosingBlock != nullptr)
    {
        // Turn "f();" into "type s0[size]; f(s0);".
        TIntermDeclaration *tempDeclaration = nullptr;
        TVariable *tempVariable = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
                                                      EvqTemporary, &tempDeclaration);
        TIntermSymbol *tempReference = CreateTempSymbolNode(tempVariable);

        TIntermSequence statements;
        statements.push_back(tempDeclaration);
        statements.push_back(createReplacementCall(node, tempReference));
        mMultiReplacements.push_back(
            NodeReplaceWithMultipleEntry(enclosingBlock, node, statements));
    }

    // Do not descend into the call's arguments.
    return false;
}

// Inside an array-returning function, rewrites "return expr;" as
// "angle_return = expr; return;". Branch children are never traversed.
bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
{
    if (mFunctionWithArrayReturnValue == nullptr || node->getFlowOp() != EOpReturn)
    {
        return false;
    }

    TIntermTyped *returnedValue = node->getExpression();
    ASSERT(returnedValue != nullptr);

    const TSymbolUniqueId &functionId =
        mFunctionWithArrayReturnValue->getFunction()->uniqueId();
    ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());

    // angle_return = expr;
    TIntermSymbol *outParameter =
        new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
    TIntermBinary *assignment = new TIntermBinary(EOpAssign, outParameter, returnedValue);
    assignment->setLine(returnedValue->getLine());

    // return;
    TIntermBranch *plainReturn = new TIntermBranch(EOpReturn, nullptr);
    plainReturn->setLine(node->getLine());

    TIntermSequence statements;
    statements.push_back(assignment);
    statements.push_back(plainReturn);
    mMultiReplacements.push_back(
        NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, statements));
    return false;
}

// Rewrites "a = f();" (a is an array, f a user-defined function) as the call "f(a);".
// Binary nodes' children are never traversed.
bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    if (node->getOp() != EOpAssign || !node->getLeft()->isArray())
    {
        return false;
    }

    TIntermAggregate *call = node->getRight()->getAsAggregate();
    ASSERT(call == nullptr || call->getOp() != EOpCallInternalRawFunction);

    const bool isUserFunctionCall = call != nullptr && call->getOp() == EOpCallFunctionInAST;
    if (isUserFunctionCall)
    {
        // The assigned-to array becomes the call's out argument.
        queueReplacement(createReplacementCall(call, node->getLeft()),
                         OriginalNode::IS_DROPPED);
    }
    return false;
}

}  // namespace

// Public entry point: converts array return values under |root| into out parameters.
// All of the work is done by the file-local traverser.
void ArrayReturnValueToOutParameter(TIntermNode *root, TSymbolTable *symbolTable)
{
    ArrayReturnValueToOutParameterTraverser::apply(root, symbolTable);
}

}  // namespace sh