What?
An attention-based model where each element of a set can be updated with a different, dynamically chosen set of weights.
Why?
Cognitive science distinguishes declarative knowledge from procedural knowledge. The paper takes inspiration from this and wants to show that the declarative/procedural split is a useful one for machine learning models.
How?
(Figure: the SCOFF pseudocode, source: original paper. I was not brave enough to retype it.)
I believe the figure is self-explanatory; the authors did an amazing job making the pseudocode crystal clear. Some comments:
- Step 1: Process the image by position $p$ with a fully convolutional net:
- Turns the image into a set of tokens (one per position).
- I did not understand how the position is encoded, though.
- Step 2: Soft competition among OFs (object files) to select regions of the input to process:
- Each OF node attends to all input nodes and outputs a weighted sum over all positions (one sum per OF).
- Step 3: OFs pick the most relevant schema and update:
- We need to update the OF nodes, but we can update them with multiple sets of weights (each schema is just a set of weights).
- Every schema is tried on every OF, and then we pick the argmax instead of doing a soft attention combination.
- The Gumbel-softmax trick is used to backpropagate through the discrete choice.
- Step 4: Soft competition among OFs to transmit relevant information to each OF:
- At this point, each OF node has had its state updated by a single set of weights chosen from those available (the schemata).
- Now OFs attend to each other and update as $h_k \leftarrow h_k + \sum_{k'} s_{k,k'} \hat{v}_{k'}$: a typical transformer-like update, except the residual addition uses the Step 3 output. (A minimal sketch of Steps 2–4 follows this list.)
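Since the pseudocode figure itself is not reproduced here, below is a minimal PyTorch sketch of how I read Steps 2–4. This is purely my interpretation, not the authors' code: the GRU cells standing in for the schemata, the scoring function, and all names and dimensions (`d`, `n_schemata`, `ScoffStepSketch`) are my own assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScoffStepSketch(nn.Module):
    def __init__(self, d=64, n_schemata=3):
        super().__init__()
        self.d = d
        # Step 2: OF states act as queries over the input tokens (keys/values).
        self.q_in = nn.Linear(d, d)
        self.k_in = nn.Linear(d, d)
        self.v_in = nn.Linear(d, d)
        # Step 3: each schema is just a separate set of update weights
        # (a GRU cell here, standing in for whatever the paper uses).
        self.schemata = nn.ModuleList([nn.GRUCell(d, d) for _ in range(n_schemata)])
        # Scores how well an (OF state, candidate update) pair matches.
        self.score = nn.Linear(2 * d, 1)
        # Step 4: OF-to-OF attention.
        self.q_of = nn.Linear(d, d)
        self.k_of = nn.Linear(d, d)
        self.v_of = nn.Linear(d, d)

    def forward(self, h, x):
        # h: (n_of, d) object-file states; x: (n_tokens, d) input tokens (Step 1 output).
        # --- Step 2: soft competition over input positions ------------------
        attn = torch.softmax(self.q_in(h) @ self.k_in(x).t() / self.d ** 0.5, dim=-1)
        inp = attn @ self.v_in(x)                       # one weighted sum per OF

        # --- Step 3: try every schema on every OF, then pick one (hard) -----
        candidates = torch.stack([gru(inp, h) for gru in self.schemata], dim=1)
        logits = self.score(
            torch.cat([h.unsqueeze(1).expand_as(candidates), candidates], dim=-1)
        ).squeeze(-1)                                   # (n_of, n_schemata)
        # Gumbel-softmax with hard=True: one-hot choice in the forward pass,
        # differentiable relaxation in the backward pass.
        choice = F.gumbel_softmax(logits, tau=1.0, hard=True)
        h_new = (choice.unsqueeze(-1) * candidates).sum(dim=1)

        # --- Step 4: OFs attend to each other, added on top of Step 3 -------
        s = torch.softmax(self.q_of(h_new) @ self.k_of(h_new).t() / self.d ** 0.5, dim=-1)
        return h_new + s @ self.v_of(h_new)


h = torch.zeros(4, 64)    # 4 object files
x = torch.randn(49, 64)   # e.g. a 7x7 grid of tokens coming out of the Step 1 ConvNet
print(ScoffStepSketch()(h, x).shape)   # torch.Size([4, 64])
```

The hard schema choice via `F.gumbel_softmax(..., hard=True)` is what lets Step 3 stay trainable despite the argmax.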
And?
- I look at the whole thing as a graph with nodes of three types: input nodes, schemata, and OF nodes.
- The most interesting part of this paper for me was using a separate set of weights for processing different nodes in the graph. However, you can probably model this as a GNN where the graph is also dynamic (each next layer removes some nodes and adds others); see the sketch right after this item:
- In the first layer, we have OF nodes and input vector nodes.
- In the second, when we select the schema, we have OF nodes and schema nodes. We can probably model this as a GNN with different edge types (one edge type per schema, and a different set of weights per edge type).
- The last layer is just a typical transformer everything-with-everything update.
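To make the edge-type idea concrete, here is a rough sketch of per-edge-type weights in the spirit of relational message passing (R-GCN-style). Everything here (names, the dense adjacency encoding) is my own construction, not anything from the paper.

```python
import torch
import torch.nn as nn


class EdgeTypedUpdate(nn.Module):
    def __init__(self, d=64, n_edge_types=3):
        super().__init__()
        # One weight matrix per edge type -- analogous to one schema
        # being one set of update weights.
        self.per_type = nn.ModuleList([nn.Linear(d, d) for _ in range(n_edge_types)])

    def forward(self, h, adj):
        # h: (n_nodes, d) node states; adj: (n_edge_types, n_nodes, n_nodes)
        # adjacency per edge type (dense 0/1 masks for simplicity).
        msg = torch.zeros_like(h)
        for r, lin in enumerate(self.per_type):
            msg = msg + adj[r] @ lin(h)   # messages sent over edges of type r
        return torch.relu(h + msg)


h = torch.randn(6, 64)                   # e.g. OF nodes plus schema nodes
adj = torch.zeros(3, 6, 6)
adj[0, 0, 4] = 1.0                       # OF 0 linked to node 4 via edge type 0
print(EdgeTypedUpdate()(h, adj).shape)   # torch.Size([6, 64])
```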
- It's probably nice as motivation, but I believe the paper uses too many programming/cognitive-science metaphors; I found them obstructing a clear understanding of the paper. At the same time, I believe the paper misses a lot of relevant and important literature:
- Object-Oriented MDPs and Relational MDPs.
- General transformer and GNN work in RL. I believe we did a good job discussing that in Section 5 of "My Body is a Cage..."
- Bayesian Multi-Task Reinforcement Learning by Lazaric and Ghavamzadeh.
- If we forget all the schemata/OF fanciness, SCOFF automatically picks which set of weights to use when updating each node.
- This is also what happens in Lazaric's model, where there are several classes of value functions to choose from (sampled from a Dirichlet process).
- I don't think the GRU baseline is a fair baseline to use. The authors consider a factored state space → a transformer would be the most obvious architecture to compare against. It would not use the fancy schemata/OFs (→ a single set of weights to update every node), but this would be an interesting ablation to show and reason about. It would also show whether the schema/OF metaphors are really useful here. It's not just my whim: the authors say that their key contribution is to demonstrate the feasibility and benefit of factorizing declarative knowledge and procedural knowledge. The feasibility has been shown, yes. But the benefit has not, because we can get a factored state space without declarative knowledge (i.e. with a single schema).
- I would love to know why we select a single schema for an update instead of taking a soft combination of the outputs of all of them.
- The authors show that when given enough schemata, the model does not necessarily use all of them. It would be helpful to understand why this happens and how we can reuse this knowledge in other settings.
- I would also like a discussion of how the heads of a transformer differ from schemata.
- Another interesting discussion would be comparing the Step 4 update (adding the Step 3 output to the current attention-weighted sum) to having a skip connection before feeding the result to the feedforward part of a transformer (a tiny sketch of both variants follows).
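To make that question concrete, here is a toy sketch of the two residual placements I have in mind. The names and the use of `nn.MultiheadAttention` are my own, purely illustrative choices, not the paper's implementation.

```python
import torch
import torch.nn as nn

d = 64
attn = nn.MultiheadAttention(d, num_heads=4, batch_first=True)
ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.ReLU(), nn.Linear(4 * d, d))

h_step3 = torch.randn(1, 4, d)   # OF states right after the (hard) schema update

# Step 4 as I read it: the attention output is added on top of the Step 3 state.
a, _ = attn(h_step3, h_step3, h_step3)
h_scoff = h_step3 + a

# Standard transformer block: same residual around attention, then another
# skip connection around the feedforward part.
a, _ = attn(h_step3, h_step3, h_step3)
h_tf = h_step3 + a
h_tf = h_tf + ffn(h_tf)
```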