What?
Learning graph-structured state representations for text-based games (TBG) in a data-driven way. Heuristics shall not pass!
Why?
Text-based games are a proxy for studying how agents can exploit language to interact with the environment. They are extremely challenging, and heuristics are often used when approaching them. Can we do without?
How?
(figure from the original paper)
Problem setting:
- TBGs are POMDPs
- TBGs are well-aligned with dynamic graphs, since game states can be represented as relational data (see the tiny sketch after this list).
- Nodes are game entities and their conditions (e.g. 'carrots' and 'sliced')
- Edges are relations ('is' or 'north of')
- We need to estimate the game state from raw text observations.
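To make the relational view concrete, here is a tiny sketch (entity and relation names are made up; the paper's actual vocabularies are larger, with $R = 10$ relation types and $N = 99$ entities) of how a state given as (subject, relation, object) triples can be packed into the dense $R \times N \times N$ tensor that the belief graph below lives in:

```python
# Minimal sketch of the relational state representation (illustrative names,
# not the paper's actual vocabularies). A game state is a set of
# (subject, relation, object) triples packed into a dense R x N x N tensor.
import numpy as np

entities = ["player", "kitchen", "carrot", "sliced"]   # N entities (paper: N = 99)
relations = ["is", "in", "north_of"]                    # R relation types (paper: R = 10)
ent_id = {e: i for i, e in enumerate(entities)}
rel_id = {r: i for i, r in enumerate(relations)}

state_triples = [
    ("player", "in", "kitchen"),
    ("carrot", "in", "kitchen"),
    ("carrot", "is", "sliced"),
]

# +1 where a relation holds, -1 elsewhere, matching the [-1, 1] range
# used for the belief graph.
g = -np.ones((len(relations), len(entities), len(entities)), dtype=np.float32)
for s, r, o in state_triples:
    g[rel_id[r], ent_id[s], ent_id[o]] = 1.0

print(g.shape)  # (3, 4, 4) here; (10, 99, 99) with the paper's vocabularies
```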
Components:
(figure from the original paper)
- Graph updater
- The goal of this component is to output a belief graph $\mathcal{G} \in [-1,1]^{R\times N \times N}$, where $R$ is the number of relation types and $N$ is the number of entities.
- The authors assume that the vocabulary of possible relations and entities is known in advance, with $R = 10$ and $N = 99$.
- Blocks
- $f_d$
- An MLP that takes a latent graph representation $h_t$ and outputs a $\mathcal{G}$ of fixed size (as defined above)!
- Text encoder
- A transformer encoder that takes the textual observation and action candidate list.
- Graph Encoder
- An R-GCN which operates on a complete graph (we do not predict discrete edges, but rather real values lying in $[-1,1]$).
- $f_\Delta$
- This is a network that takes text and action representations as well as a graph hidden representation to produce a change in the graph $\Delta g_t$.
- To get the new graph, $\Delta g_t$ is used to update the latent $h_t$, which is then fed to $f_d$ (see the sketch after this component's bullets).
- The graph updater is pre-trained in a supervised way on a dataset of demonstrations as well as on data collected by a random policy. After that, its weights are frozen.
- Two pre-training approaches:
- Reconstruct the text observation from the belief graph, using an NLL loss.
- Discriminate between representations of the true observations and corrupted observations randomly sampled from the training data, using a binary cross-entropy loss (a sketch of this objective follows below).
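Putting the graph-updater blocks together, here is a minimal sketch of the forward pass. This is not the paper's implementation: the shapes, the single-layer transformer stand-in for the text encoder, and the GRU-style aggregation of $\Delta g_t$ into $h_t$ are my assumptions; the R-GCN graph encoder is sketched later with the action selector.

```python
# Minimal sketch of the graph-updater forward pass (shapes and the GRU-style
# recurrent aggregation are assumptions for illustration, not the paper's exact design).
import torch
import torch.nn as nn

R, N, D = 10, 99, 64          # relation types, entities, hidden size

class GraphUpdater(nn.Module):
    def __init__(self):
        super().__init__()
        # Text encoder stand-in: the paper uses a transformer encoder over the
        # observation and the action candidates; here a single layer.
        self.text_enc = nn.TransformerEncoderLayer(d_model=D, nhead=4, batch_first=True)
        # f_delta: fuses text/action representations with the previous graph latent.
        self.f_delta = nn.Sequential(nn.Linear(3 * D, D), nn.ReLU(), nn.Linear(D, D))
        # Recurrent aggregation of the graph latent (assumed GRU cell).
        self.rnn = nn.GRUCell(D, D)
        # f_d: MLP decoding the latent into the belief graph in [-1, 1]^{R x N x N}.
        self.f_d = nn.Sequential(nn.Linear(D, D), nn.ReLU(), nn.Linear(D, R * N * N), nn.Tanh())

    def forward(self, obs_tokens, act_tokens, h_prev):
        # obs_tokens, act_tokens: (batch, seq, D) pre-embedded token sequences
        o = self.text_enc(obs_tokens).mean(dim=1)      # (batch, D)
        a = self.text_enc(act_tokens).mean(dim=1)      # (batch, D)
        delta_g = self.f_delta(torch.cat([o, a, h_prev], dim=-1))
        h_t = self.rnn(delta_g, h_prev)                # new graph latent
        belief = self.f_d(h_t).view(-1, R, N, N)       # belief graph G_t
        return belief, h_t

updater = GraphUpdater()
obs = torch.randn(1, 12, D); act = torch.randn(1, 4, D); h0 = torch.zeros(1, D)
belief, h1 = updater(obs, act, h0)
print(belief.shape)  # torch.Size([1, 10, 99, 99])
```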
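The second pre-training objective is simple enough to write down. A minimal sketch, assuming a dot-product scorer between graph and observation representations (the exact scorer is my guess):

```python
# Minimal sketch of the discriminative pre-training objective: score
# (graph representation, true observation) pairs as "real" and pairs with a
# corrupted observation sampled from the training data as "fake".
# The dot-product scorer is an assumption for illustration.
import torch
import torch.nn.functional as F

def discrimination_loss(graph_repr, obs_repr, corrupted_obs_repr):
    # all inputs: (batch, D)
    pos_logits = (graph_repr * obs_repr).sum(dim=-1)            # should be labelled 1
    neg_logits = (graph_repr * corrupted_obs_repr).sum(dim=-1)  # should be labelled 0
    logits = torch.cat([pos_logits, neg_logits])
    labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
    return F.binary_cross_entropy_with_logits(logits, labels)

loss = discrimination_loss(torch.randn(8, 64), torch.randn(8, 64), torch.randn(8, 64))
```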
- Action selector
- This is the RL part of the paper, which is basically a $Q$-function.
- Blocks
- Graph encoder
- An R-GCN that takes the predicted graph $\mathcal{G}$ and converts this to a hidden vector.
- Text encoder
- A transformer encoder that takes an observation and produces a hidden vector.
- Representation aggregator
- A network similar to $f_\Delta$ in the graph updater component.
- Scorer
- This is like the last layer of a $Q$-function: it takes an action representation as input and outputs a scalar.
- The whole thing is trained with Double $Q$-learning, prioritized experience replay, and multi-step returns. Since there is a distribution of tasks, games are sampled from it at the beginning of each episode, and all game rollouts are stored in the same buffer (sketches of the action scoring and the multi-step target follow).
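Here is a minimal sketch of the action selector: a relational message-passing layer over the complete graph (every entity pair exchanges messages, weighted by the continuous belief values), a text encoder, an aggregator, and a scorer that produces one $Q$-value per candidate action. The shapes, the mean-pooling, and the exact way the belief values weight the messages are my assumptions:

```python
# Minimal sketch of the action selector. Not the paper's implementation:
# shapes and the weighting of messages by belief values are assumptions.
import torch
import torch.nn as nn

R, N, D = 10, 99, 64

class RelationalLayer(nn.Module):
    """One message-passing step over a complete graph with continuous relation weights."""
    def __init__(self):
        super().__init__()
        self.w_rel = nn.Parameter(torch.randn(R, D, D) * 0.02)  # one transform per relation
        self.w_self = nn.Linear(D, D)

    def forward(self, node_feats, belief):
        # node_feats: (N, D), belief: (R, N, N) with values in [-1, 1]
        msgs = torch.einsum("rij,jd,rde->ie", belief, node_feats, self.w_rel) / (R * N)
        return torch.relu(self.w_self(node_feats) + msgs)

class ActionSelector(nn.Module):
    def __init__(self):
        super().__init__()
        self.entity_emb = nn.Parameter(torch.randn(N, D) * 0.02)
        self.gnn = RelationalLayer()
        self.text_enc = nn.TransformerEncoderLayer(d_model=D, nhead=4, batch_first=True)
        self.aggregate = nn.Linear(2 * D, D)
        self.scorer = nn.Linear(2 * D, 1)   # (state repr, action repr) -> scalar Q

    def forward(self, belief, obs_tokens, action_tokens):
        # belief: (R, N, N); obs_tokens: (1, L, D); action_tokens: (A, L_a, D)
        graph_repr = self.gnn(self.entity_emb, belief).mean(dim=0)       # (D,)
        obs_repr = self.text_enc(obs_tokens).mean(dim=1).squeeze(0)      # (D,)
        state = torch.relu(self.aggregate(torch.cat([graph_repr, obs_repr])))
        act_reprs = self.text_enc(action_tokens).mean(dim=1)             # (A, D)
        q = self.scorer(torch.cat([state.expand(act_reprs.size(0), -1), act_reprs], dim=-1))
        return q.squeeze(-1)                                             # (A,) one Q per candidate

selector = ActionSelector()
q_values = selector(torch.rand(R, N, N) * 2 - 1, torch.randn(1, 12, D), torch.randn(5, 6, D))
best_action = q_values.argmax().item()
```

The point of `RelationalLayer` is the detail discussed below: $\mathcal{G}$ is never thresholded into a discrete adjacency matrix; it enters the GNN update as continuous weights over all entity pairs.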
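And a minimal sketch of the multi-step Double $Q$-learning target (the discount, step count, and network interfaces are placeholders):

```python
# Minimal sketch of the multi-step Double Q-learning target used to train the
# action selector. Hyperparameters and interfaces are placeholders.
import torch

def n_step_double_q_target(rewards, next_state_online_q, next_state_target_q,
                           done, gamma=0.99):
    # rewards: (n,) intermediate rewards r_t .. r_{t+n-1}
    # next_state_*_q: (A,) Q-values for the state n steps ahead
    n = rewards.shape[0]
    discounts = gamma ** torch.arange(n, dtype=rewards.dtype)
    n_step_return = (discounts * rewards).sum()
    # Double Q-learning: the online net picks the action, the target net evaluates it.
    best_action = next_state_online_q.argmax()
    bootstrap = next_state_target_q[best_action]
    return n_step_return + (gamma ** n) * bootstrap * (1.0 - done)

target = n_step_double_q_target(
    rewards=torch.tensor([0.0, 1.0, 0.0]),
    next_state_online_q=torch.randn(5),
    next_state_target_q=torch.randn(5),
    done=0.0,
)
```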
And?
- This was not the easiest paper to read. There are so many building blocks that I was confused 90% of the time about what goes where. In particular, it took me a while to find, in the appendix, that the R-GCN works on a complete graph and that the predicted graph $\mathcal{G}$ is not really an adjacency matrix used for a GNN update.
- IMHO, the most interesting bits of this work are the baselines and the ablations.
- Pre-training the graph updater on the ground-truth graphs does not work as well as vanilla GATA, which is surprising. However, this is a bit more complicated than just pre-training on the ground-truth data, since the output of the graph updater is discrete here. The authors hypothesise that the loss in performance comes from compounding and round-off errors when performing discrete operations on the graphs.
- At the same time, using the ground-truth data for action selection works better, most likely because the task becomes fully observable.
- The results are impressive, but it still feels so weird to convert graphs to vectors and back!
This note is a part of my paper notes series. You can find more here or on Twitter. I also have a blog.