<?xml version="1.0" encoding="utf-8" standalone="yes"?><rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom"><channel><title>Publications | Tobias Nauen</title><link>https://nauen-it.de/publications/</link><atom:link href="https://nauen-it.de/publications/index.xml" rel="self" type="application/rss+xml"/><description>Publications</description><generator>HugoBlox Kit (https://hugoblox.com)</generator><language>en-us</language><image><url>https://nauen-it.de/media/icon.svg</url><title>Publications</title><link>https://nauen-it.de/publications/</link></image><item><title>When Pretty Isn't Useful: Investigating Why Modern Text-to-Image Models Fail as Reliable Training Data Generators</title><link>https://nauen-it.de/publications/when-pretty-isnt-useful/</link><pubDate>Mon, 23 Feb 2026 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/when-pretty-isnt-useful/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>PRISM: Diversifying Dataset Distillation by Decoupling Architectural Priors</title><link>https://nauen-it.de/publications/prism/</link><pubDate>Thu, 13 Nov 2025 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/prism/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>HyperCore: Coreset Selection under Noise via Hypersphere Models</title><link>https://nauen-it.de/publications/hypercore/</link><pubDate>Fri, 26 Sep 2025 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/hypercore/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>SubZeroCore: A Submodular Approach with Zero Training for Coreset Selection</title><link>https://nauen-it.de/publications/subzerocore/</link><pubDate>Fri, 26 Sep 2025 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/subzerocore/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>When 512×512 is not Enough: Local Degradation-Aware Multi-Diffusion for Extreme Image Super-Resolution</title><link>https://nauen-it.de/publications/zoomed-in-diffused-out/</link><pubDate>Sun, 14 Sep 2025 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/zoomed-in-diffused-out/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
&lt;/p&gt;</description></item><item><title>ForAug: Recombining Foregrounds and Backgrounds to Improve Vision Transformer Training with Bias Mitigation</title><link>https://nauen-it.de/publications/foraug/</link><pubDate>Wed, 12 Mar 2025 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/foraug/</guid><description>&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;&lt;img alt="gif"
src="https://nauen-it.de/publications/foraug/images/foraug-gif.gif"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;Image classification – teaching computers to label images – is a cornerstone of AI vision, powering everything from medical diagnosis to autonomous driving.
Datasets like ImageNet have been crucial, especially with the rise of powerful models like Vision Transformers (ViTs).&lt;/p&gt;
&lt;p&gt;However, unlike older Convolutional Neural Networks (CNNs), ViTs don&amp;rsquo;t inherently understand that an object remains the same regardless of its position in an image (they lack &amp;ldquo;translation equivariance&amp;rdquo;).
Standard data augmentation techniques (like flipping or cropping) help, but they weren&amp;rsquo;t specifically designed for this trait of ViTs.&lt;/p&gt;
&lt;p&gt;To tackle these problems, we propose &lt;strong&gt;ForAug&lt;/strong&gt;, a novel data augmentation for ViTs.
The core idea?
Make the spatial relationships explicit in the training data.
ForAug achieves this by:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;Separating foreground objects from their backgrounds in the dataset.&lt;/li&gt;
&lt;li&gt;Recombining these objects with different backgrounds on-the-fly during training.&lt;/li&gt;
&lt;li&gt;Controlling the object&amp;rsquo;s size and position during this recombination.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;The results?
Training ViTs with ForAug instead of standard ImageNet boosts accuracy by up to 4.5 percentage points on ImageNet classification and significantly cuts error rates (up to 39.3% reduction) on downstream tasks.&lt;/p&gt;
&lt;p&gt;Furthermore, ForAug provides powerful new ways to analyze model biases.
Researchers can now precisely measure:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Background Robustness: How much does the background influence the prediction?&lt;/li&gt;
&lt;li&gt;Foreground Focus: Does the model correctly focus on the main object?&lt;/li&gt;
&lt;li&gt;Center &amp;amp; Size Bias: Is the model overly reliant on objects being centered or a specific size?&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;Training with ForAug demonstrably reduces these biases, leading to more robust models.&lt;/p&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;This post is just a short overview over ForAug. For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h1 id="foraug-method"&gt;ForAug (Method)&lt;/h1&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="ForAug Flowchart"
srcset="https://nauen-it.de/publications/foraug/images/fig-2_hu_2f0955c903beee0e.webp 320w, https://nauen-it.de/publications/foraug/images/fig-2_hu_5313138d6618e3b9.webp 480w, https://nauen-it.de/publications/foraug/images/fig-2_hu_70ceae1cbeb25b52.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/fig-2_hu_2f0955c903beee0e.webp"
width="760"
height="291"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
So, how does ForAug actually build these dynamic training images?
The process, visualized above, involves two main stages: an offline &lt;strong&gt;Segmentation&lt;/strong&gt; stage and an online &lt;strong&gt;Recombination&lt;/strong&gt; stage.&lt;/p&gt;
&lt;h2 id="segmentation"&gt;Segmentation&lt;/h2&gt;
&lt;p&gt;The process kicks off with the Segmentation stage, a one-time, offline preparation step performed before model training even begins.
Think of it as carefully prepping the visual ingredients.
Here, we leverage the state-of-the-art Grounded SAM segmentation model, guiding it with the known class label of each image (e.g., instructing it to specifically find the &amp;lsquo;golden retriever&amp;rsquo;) to precisely isolate the main subject.
Once the foreground object is digitally &amp;lsquo;cut out&amp;rsquo;, an object removal or &amp;lsquo;inpainting&amp;rsquo; model intelligently fills the resulting hole in the original background, ensuring the backdrop looks natural and plausible.
Crucially, not all generated assets make the cut; a filtering step employs other pre-trained AI models to assess quality.
This ensures only clearly defined foregrounds and clean backgrounds – ones that don&amp;rsquo;t inadvertently give away the object&amp;rsquo;s identity or look overly artificial – are selected.
This meticulous preparation yields the core assets for ForAug: a collection of ready-to-use foreground objects (with transparency) and a diverse pool of cleaned-up backgrounds.&lt;/p&gt;
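&lt;p&gt;To make the flow of this offline stage concrete, here is a minimal sketch of the asset-generation loop. It is not the actual ForAug code: the &lt;code&gt;segment&lt;/code&gt;, &lt;code&gt;inpaint&lt;/code&gt;, and &lt;code&gt;is_clean&lt;/code&gt; callables are placeholders standing in for the Grounded SAM, inpainting, and filtering models described above.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Sketch of the offline Segmentation stage. The three models are passed in as
# callables; their names and signatures are placeholders, not the ForAug implementation.
from pathlib import Path

def build_foraug_assets(dataset, segment, inpaint, is_clean, out_dir):
    """dataset yields (PIL image, class label) pairs; segment/inpaint/is_clean
    wrap the segmentation, inpainting, and quality-filter models."""
    out = Path(out_dir)
    (out / "foregrounds").mkdir(parents=True, exist_ok=True)
    (out / "backgrounds").mkdir(parents=True, exist_ok=True)
    for idx, (image, label) in enumerate(dataset):
        mask = segment(image, prompt=label)           # text-prompted mask of the main subject
        foreground = image.copy()
        foreground.putalpha(mask)                     # keep the object, make the rest transparent
        background = inpaint(image, mask)             # fill the hole left by the removed object
        if is_clean(foreground, background, label):   # drop leaky or overly artificial assets
            foreground.save(out / "foregrounds" / f"{idx}_{label}.png")
            background.save(out / "backgrounds" / f"{idx}.png")
&lt;/code&gt;&lt;/pre&gt;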
&lt;h2 id="recombination"&gt;Recombination&lt;/h2&gt;
&lt;p&gt;With the assets prepared, the real action unfolds during the Recombination stage, which happens dynamically online while the Vision Transformer is training.
This is where ForAug truly comes alive, creating new training examples on the fly.
For every foreground object prepared in the first stage, the system randomly selects a background to pair it with.
This background might be the object&amp;rsquo;s original one, perhaps one from another image belonging to the same object class, or even a completely unrelated background drawn from the entire dataset to maximize contextual variety.
The chosen foreground object is then randomly resized (within sensible limits based on its original appearance) and placed at a random position onto this background canvas.
To create a more seamless integration, a subtle smoothing effect is applied to the object&amp;rsquo;s edges where it meets the new background.
Only after this dynamic composition is complete does the resulting image undergo the standard data augmentation techniques commonly used in AI training, like random color shifts or minor flips.
This constant mixing-and-matching means that each time the AI cycles through the training data, it encounters familiar objects in entirely new visual contexts.
This directly forces the ViT to learn robust features that identify the object itself, effectively teaching it the spatial invariance that doesn&amp;rsquo;t come built-in, by demonstrating repeatedly that appearance, not specific placement or background, is what defines the object.&lt;/p&gt;
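&lt;p&gt;A minimal sketch of this recombination step (simplified, with illustrative parameter values rather than the exact settings from the paper) could look as follows:&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Sketch of the online Recombination stage; foregrounds are RGBA PIL images,
# backgrounds are RGB PIL images. Scale range and blur radius are illustrative.
import random
from PIL import ImageFilter

def recombine(foreground, background_pool, scale_range=(0.5, 1.5), edge_blur=2):
    bg = random.choice(background_pool).copy()          # original, same-class, or arbitrary background
    scale = random.uniform(*scale_range)                 # random resize within sensible limits
    fg = foreground.resize((int(foreground.width * scale), int(foreground.height * scale)))
    alpha = fg.getchannel("A").filter(ImageFilter.GaussianBlur(edge_blur))  # smooth the edges
    x = random.randint(0, max(0, bg.width - fg.width))   # random position on the canvas
    y = random.randint(0, max(0, bg.height - fg.height))
    bg.paste(fg, (x, y), mask=alpha)                      # composite foreground over background
    return bg  # standard augmentations (flips, color jitter, ...) are applied afterwards
&lt;/code&gt;&lt;/pre&gt;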
&lt;h1 id="experiments"&gt;Experiments&lt;/h1&gt;
&lt;h2 id="image-classification-results"&gt;Image Classification Results&lt;/h2&gt;
&lt;p&gt;We compare training on ImageNet with and without ForAug for 10 different models:
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="ImageNet results"
srcset="https://nauen-it.de/publications/foraug/images/foraug-imagenet-results_hu_d7381cc18cb8227f.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-imagenet-results_hu_9a346107ce9df01c.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-imagenet-results_hu_851c625e34531f5c.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-imagenet-results_hu_d7381cc18cb8227f.webp"
width="760"
height="647"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
We find that training with ForAug increases the accuracy of every model, by up to 4.5 percentage points.
It also combats the overfitting problem of larger models.&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Downstream Results"
srcset="https://nauen-it.de/publications/foraug/images/foraug-downstream-results_hu_bc13a772bd62ba95.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-downstream-results_hu_4093f7149dcda2cb.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-downstream-results_hu_67cdcd2209a42509.webp 671w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-downstream-results_hu_bc13a772bd62ba95.webp"
width="671"
height="760"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
When finetuning these models on 5 fine-grained downstream datasets, we find that the ForAug-pretrained models consistently outperform the ImageNet-pretrained ones, especially the transformer-based models.&lt;/p&gt;
&lt;h2 id="model-robustness"&gt;Model Robustness&lt;/h2&gt;
&lt;p&gt;We also evaluate multiple robustness metrics.&lt;/p&gt;
&lt;h3 id="background-robustness"&gt;Background Robustness&lt;/h3&gt;
&lt;p&gt;We assess the background robustness of models by inspecting how the accuracy changes when evaluating with ForAug using backgrounds from the &lt;em&gt;same&lt;/em&gt; class compared to backgrounds from &lt;em&gt;all&lt;/em&gt; classes:&lt;/p&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Background Robustness Scores"
srcset="https://nauen-it.de/publications/foraug/images/foraug-background-robustness_hu_2489c03de26af797.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-background-robustness_hu_dc3b8b16d1a5fc3a.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-background-robustness_hu_289f170ed90e6dd3.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-background-robustness_hu_2489c03de26af797.webp"
width="760"
height="224"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
Training with ForAug reduces the &lt;em&gt;Background Gap&lt;/em&gt; for all transformer models.&lt;/p&gt;
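&lt;p&gt;Written out (our shorthand here, not notation from the paper), the &lt;em&gt;Background Gap&lt;/em&gt; is simply the accuracy difference between these two evaluation settings:&lt;/p&gt;
$$
\text{BG Gap} = \text{Acc}_\text{same-class backgrounds} - \text{Acc}_\text{all-class backgrounds}
$$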
&lt;h3 id="foreground-focus"&gt;Foreground Focus&lt;/h3&gt;
&lt;p&gt;Since we have the foreground segmentation masks, we can also investigate the foreground focus of the trained models.
For this, we utilize different input-importance metrics like GradCAM and IntegratedGradients (IG).
We define a model&amp;rsquo;s foreground focus by how much more it focuses on the foreground object compared to a uniform importance distribution:
&lt;/p&gt;
$$
\text{FG Focus}(M; \text{img}) = \frac{\text{Area}(\text{img}) \hspace{5pt} \text{Importance}_M(\text{fg})}{\text{Area}(\text{fg}) \hspace{5pt} \text{Importance}_M(\text{img})}
$$&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Foreground Focus Scores"
srcset="https://nauen-it.de/publications/foraug/images/foraug-foreground-focus_hu_1afeb71ecf915471.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-foreground-focus_hu_4efb61ce3baf1f8b.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-foreground-focus_hu_467c7e1ebb906659.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-foreground-focus_hu_1afeb71ecf915471.webp"
width="760"
height="204"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
We find that training with ForAug improves the foreground focus of all models, in most cases significantly.&lt;/p&gt;
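&lt;p&gt;As a small sketch, the score above can be computed directly from a per-pixel importance map (e.g. a GradCAM or IG attribution) and the binary foreground mask; the function below is our illustration, not code from the paper, and taking absolute attributions is an assumption:&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Sketch: foreground-focus score from an importance map and a foreground mask.
import numpy as np

def foreground_focus(importance, fg_mask):
    """importance: (H, W) attribution map; fg_mask: (H, W) boolean foreground mask."""
    importance = np.abs(importance).astype(np.float64)  # assumption: use absolute attributions
    fg_mask = fg_mask.astype(bool)
    area_img = importance.size                          # Area(img)
    area_fg = fg_mask.sum()                             # Area(fg)
    imp_img = importance.sum()                          # Importance_M(img)
    imp_fg = importance[fg_mask].sum()                  # Importance_M(fg)
    return (area_img * imp_fg) / (area_fg * imp_img)    # values above 1: focus on the object
&lt;/code&gt;&lt;/pre&gt;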
&lt;h3 id="center-bias"&gt;Center Bias&lt;/h3&gt;
&lt;p&gt;Since we can freely change the object&amp;rsquo;s position and size, we can evaluate the model bias when the position changes.
For this, we subdivide the image into $3 \times 3$ sections (nonants) and place each object only in one nonant.
We then compare the accuracy of a model when an object is in a specific nonant to when it&amp;rsquo;s in the center nonant.&lt;/p&gt;
&lt;p&gt;Our center-bias score is defined as the mean of (1) the worst accuracy in a corner and (2) the worst accuracy on an edge, relative to the accuracy in the center.
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Center Bias Table"
srcset="https://nauen-it.de/publications/foraug/images/foraug-center-bias_hu_4292720d3de7ff2c.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-center-bias_hu_86b74bbd53280813.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-center-bias_hu_1c2231712f30f607.webp 483w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-center-bias_hu_4292720d3de7ff2c.webp"
width="483"
height="760"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
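&lt;p&gt;One way to write this score explicitly (our reading of the definition above, with $\text{acc}_n$ denoting the accuracy when the object is placed in nonant $n$):&lt;/p&gt;
$$
\text{Center Bias} = \frac{1}{2} \left( \min_{c \in \text{corners}} \frac{\text{acc}_c}{\text{acc}_\text{center}} + \min_{e \in \text{edges}} \frac{\text{acc}_e}{\text{acc}_\text{center}} \right)
$$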
&lt;p&gt;We visualize the center bias for 3 instantiations of each model.
Training with ForAug significantly reduces the center bias, especially for larger transformers.
We also find that when training on ImageNet, models consistently perform better when an object is on the right side of an image compared to the left side (even though we use 50% random flipping during training of all models).&lt;/p&gt;
&lt;h3 id="size-bias"&gt;Size Bias&lt;/h3&gt;
&lt;p&gt;We vary the object size by an additional factor of $f_\text{size}$ to see how the model accuracy changes relative to $f_\text{size} = 1$.
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Size Bias Plot"
srcset="https://nauen-it.de/publications/foraug/images/foraug-size-bias_hu_6d10bfb297ada5cb.webp 320w, https://nauen-it.de/publications/foraug/images/foraug-size-bias_hu_7f5b922edf83b92b.webp 480w, https://nauen-it.de/publications/foraug/images/foraug-size-bias_hu_13974a67c0bc523a.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/foraug/images/foraug-size-bias_hu_6d10bfb297ada5cb.webp"
width="760"
height="377"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Using ForAug significantly reduces the accuracy drop-off when going towards smaller objects.
These gains come on top of the overall better accuracy (at $f_\text{size} = 1$).&lt;/p&gt;
&lt;h1 id="conclusion"&gt;Conclusion&lt;/h1&gt;
&lt;p&gt;So, what&amp;rsquo;s the big takeaway from ForAug?
This research introduces a genuinely novel data augmentation scheme designed specifically to enhance how Vision Transformers learn to classify images.
By cleverly separating foreground objects from their backgrounds and dynamically recombining them during training, ForAug tackles a key characteristic of Transformer models head-on.&lt;/p&gt;
&lt;p&gt;As the results clearly demonstrate, this dynamic approach pays off significantly.
Training models with ForAug leads to substantial performance boosts on the standard ImageNet benchmark and translates to impressive gains on related fine-grained classification tasks downstream.&lt;/p&gt;
&lt;p&gt;But the impact of ForAug extends beyond just improving accuracy scores.
It also provides a powerful and much-needed framework for analyzing model behavior and uncovering hidden biases.
Crucially, the experiments show that training with ForAug doesn&amp;rsquo;t just highlight these biases – it actively reduces them.
This results in models that are not only more accurate but also more robust, reliable, and generalizable to varied real-world conditions.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
&lt;/p&gt;</description></item><item><title>Which Transformer to Favor: A Comparative Analysis of Efficiency in Vision Transformers</title><link>https://nauen-it.de/publications/wtf-benchmark/</link><pubDate>Fri, 28 Feb 2025 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/wtf-benchmark/</guid><description>&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;The Transformer architecture
is one of the most successful models in deep learning, outperforming traditional models in multiple domains from language modeling to computer vision.
However, a major challenge in working with Transformer models is their computational complexity of $\mathcal O(N^2)$ in the size of the input $N$.
Therefore, researchers have proposed a multitude of modifications to overcome this hurdle and make Transformers more efficient.&lt;/p&gt;
&lt;p&gt;However, it is unclear which modifications and overall strategies are the most efficient.
That&amp;rsquo;s why in this paper, we will answer the following questions for the domain of &lt;strong&gt;image classification&lt;/strong&gt;:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;Which specific modifications&lt;/strong&gt; and overall strategies are the most efficient?&lt;/li&gt;
&lt;li&gt;Are these modifications &lt;strong&gt;even worth considering&lt;/strong&gt; over the baseline transformer?&lt;/li&gt;
&lt;li&gt;What &lt;strong&gt;other dimensions&lt;/strong&gt; influence efficiency, and how can I scale up my setup efficiently?&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;We tackle these questions by training more than &lt;strong&gt;45 transformer variants&lt;/strong&gt; from scratch, ensuring fair and comparable evaluation conditions.
These transformer variants have been proposed to increase the efficiency for the domains of language or computer vision.
Then we measure their &lt;em&gt;speed&lt;/em&gt; and &lt;em&gt;memory requirements&lt;/em&gt;, both at &lt;em&gt;training&lt;/em&gt; and &lt;em&gt;inference&lt;/em&gt; time.
We additionally compare to the theoretical metrics of &lt;em&gt;parameters&lt;/em&gt; and &lt;em&gt;FLOPs&lt;/em&gt;.
Our analysis is based on the Pareto front, the set of models that provide an optimal tradeoff between model performance and one aspect of efficiency.
It lets us analyze the complex multidimensional tradeoffs involved in judging efficiency.
In our plots, Pareto-optimal models are marked with a black dot, while the others have a white dot.
For an example, see the inference throughput plot in the results below.&lt;/p&gt;
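&lt;p&gt;As a quick illustration of what &amp;ldquo;Pareto optimal&amp;rdquo; means here, the sketch below marks the models that are not dominated in an accuracy-vs-throughput comparison; it is a simplified stand-in for our evaluation code, and the metric names are illustrative:&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Sketch: find the Pareto front for an accuracy-vs-throughput tradeoff.
# Both metrics are "higher is better"; for costs (memory, time), negate them first.
def pareto_front(models):
    """models: list of dicts with keys 'name', 'accuracy', 'throughput'."""
    def dominated(m):
        # m is dominated if another model is at least as good in both metrics
        # and strictly better in at least one of them
        for o in models:
            if o is m:
                continue
            if o["accuracy"] >= m["accuracy"] and o["throughput"] >= m["throughput"]:
                if o["accuracy"] > m["accuracy"] or o["throughput"] > m["throughput"]:
                    return True
        return False
    return [m["name"] for m in models if not dominated(m)]
&lt;/code&gt;&lt;/pre&gt;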
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;To see the interactive plots, go down to the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h1 id="efficient-transformers-for-computer-vision"&gt;Efficient Transformers for Computer Vision&lt;/h1&gt;
&lt;h2 id="basics-of-the-transformer-architecture"&gt;Basics of the Transformer Architecture&lt;/h2&gt;
&lt;p&gt;We briefly describe the key elements of ViT
(the Transformer baseline for image classification) that have been studied to make it more efficient, as well as its key bottleneck: the $\mathcal O(N^2)$ computational complexity of self-attention.
ViT is an adaptation of the original Transformer, taking an image as input, which is converted into a sequence of non-overlapping patches of size $p \times p$ (usually $p = 16$).
Each patch is linearly embedded into a token of size $d$, with a positional encoding being added.
A classification token &lt;code&gt;[CLS]&lt;/code&gt; is appended to the sequence, which is then fed through a Transformer encoder.
There, the self-attention mechanism computes the attention weights $A$ from the queries $Q \in \mathbb R^{N \times d}$ and keys $K \in \mathbb R^{N \times d}$ for each token from the sequence:&lt;/p&gt;
$$
A = \text{softmax}\left( \frac{QK^\top}{\sqrt{d_\text{head}}} \right) \in \mathbb R^{N \times N}
$$&lt;p&gt;This matrix encodes the global interactions between every possible pair of tokens, but it&amp;rsquo;s also the reason for the inherent $\mathcal O(N^2)$ computational complexity of the attention mechanism.
The output of attention is a sum over the values $V$ weighted by the attention weights: $X_\text{out} = AV$.
After self-attention, the sequence elements are passed through a 2-layer MLP.
In the end, only the &lt;code&gt;[CLS]&lt;/code&gt; token is used for the classification decision.&lt;/p&gt;
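&lt;p&gt;For illustration, a bare-bones version of this attention step (a single head with the weight matrices as inputs; our sketch, not the full ViT block) makes the $\mathcal O(N^2)$ cost visible in the $N \times N$ attention matrix:&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Minimal single-head self-attention sketch; x has shape (N, d).
import torch

def self_attention(x, w_q, w_k, w_v):
    q, k, v = x @ w_q, x @ w_k, x @ w_v        # queries, keys, values: (N, d_head)
    d_head = q.shape[-1]
    scores = (q @ k.T) / d_head**0.5           # (N, N): quadratic in the sequence length
    attn = torch.softmax(scores, dim=-1)       # attention weights A
    return attn @ v                            # output: weighted sum over the values
&lt;/code&gt;&lt;/pre&gt;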
&lt;h2 id="efficiency-improving-changes"&gt;Efficiency-Improving Changes&lt;/h2&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Two level taxonomy for classification of efficiency-improving changes and strategies"
srcset="https://nauen-it.de/publications/wtf-benchmark/images/wtf_two_level_taxonomy_hu_8706818db19e3f4f.webp 320w, https://nauen-it.de/publications/wtf-benchmark/images/wtf_two_level_taxonomy_hu_376d09f3b20e2d46.webp 480w, https://nauen-it.de/publications/wtf-benchmark/images/wtf_two_level_taxonomy_hu_4437156d62efa8.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/wtf-benchmark/images/wtf_two_level_taxonomy_hu_8706818db19e3f4f.webp"
width="760"
height="194"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;We systematically classify the efficient models using a two step approach:&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;strong&gt;Where&lt;/strong&gt; does the model change the baseline ViT: at the &lt;em&gt;token-mixing mechanism&lt;/em&gt;, the &lt;em&gt;token sequence&lt;/em&gt;, or at the &lt;em&gt;MLP block&lt;/em&gt;?&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;How&lt;/strong&gt; and using what strategy does the model change the baseline?&lt;/li&gt;
&lt;/ol&gt;
&lt;h3 id="i-token-mixing"&gt;(i) Token Mixing&lt;/h3&gt;
&lt;p&gt;The first and most popular approach is to change the token mixing mechanism, which directly tackles the $\mathcal O(N^2)$ computational complexity of self-attention.
We identify 7 strategies for changing the token mixing mechanism to make it more efficient:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Low-Rank Attention&lt;/strong&gt; leverages the fact that $QK^\top \in \mathbb R^{N \times N}$ is a matrix of rank $r \leq d \ll N$ and approximates it by using a low-rank representation.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Sparse Attention&lt;/strong&gt; builds on most of the attention values being very small and only explicitly calculates a subset of values of $A$.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Fixed Attention&lt;/strong&gt; uses a fixed attention matrix for all samples.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Kernel Attention&lt;/strong&gt; splits the $\text{softmax}$ into two functions that are applied to $Q$ and $K$ individually, so $A$ does not have to be calculated explicitly (see the expansion below this list):
$$
X_\text{out} = \phi(Q) \phi(K)^\top V.
$$&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Hybrid Attention&lt;/strong&gt; combines the attention mechanism with convolution layers.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Fourier Attention&lt;/strong&gt; uses the Fast Fourier Transform (FFT) to calculate the interactions in Fourier space with $\mathcal O(N \log N)$ complexity.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Non-Attention Shuffling&lt;/strong&gt; refers to other techniques of capturing interactions without using attention.&lt;/li&gt;
&lt;/ul&gt;
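&lt;p&gt;For the kernel attention strategy, the saving comes from simple associativity: instead of first forming the $N \times N$ product $\phi(Q)\phi(K)^\top$, one can first compute $\phi(K)^\top V$, whose size depends only on the (fixed) feature dimension and not on the sequence length $N$:&lt;/p&gt;
$$
X_\text{out} = \left( \phi(Q)\, \phi(K)^\top \right) V = \phi(Q) \left( \phi(K)^\top V \right),
$$
&lt;p&gt;which replaces the $\mathcal O(N^2)$ cost of explicit attention with one that is linear in $N$.&lt;/p&gt;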
&lt;h3 id="ii-token-sequence"&gt;(ii) Token Sequence&lt;/h3&gt;
&lt;p&gt;Models that change the token sequence are more prevalent in CV than in NLP.
The idea is to remove redundant information and, in doing so, turn the $\mathcal O(N^2)$ complexity to our advantage:
removing 30% of the tokens reduces the computational cost of self-attention by approximately 50%, since $(0.7 N)^2 \approx 0.49\, N^2$.
The strategies we identify are:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;&lt;strong&gt;Token Removal&lt;/strong&gt;: Removing unimportant tokens without losing critical information.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Token Merging&lt;/strong&gt;: Merging tokens to remove redundant information.&lt;/li&gt;
&lt;li&gt;&lt;strong&gt;Summary Tokens&lt;/strong&gt;: Condensing the information from the sequence into a small set of new tokens.&lt;/li&gt;
&lt;/ul&gt;
&lt;h3 id="iii-mlp-block"&gt;(iii) MLP Block&lt;/h3&gt;
&lt;p&gt;The final way of changing the architecture is only taken by two methods.
Their idea is to move computations from self-attention into the efficient MLP blocks.
This is done by expanding the MLPs or exchanging self-attention layers for additional MLPs.&lt;/p&gt;
&lt;p&gt;&lt;a name="list-of-models"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;details&gt;
&lt;summary&gt;List of Models&lt;/summary&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;Where?&lt;/th&gt;
&lt;th style="text-align: left"&gt;What?&lt;/th&gt;
&lt;th style="text-align: left"&gt;Model Name&lt;/th&gt;
&lt;th style="text-align: left"&gt;Paper Title&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Token Mixing&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Low-Rank Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Linformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Nyströmformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;XCiT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Sparse Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Swin Transformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;SwinV2&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;HaloNet&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Routing Transformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Sinkhorn Transformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Informer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Wave-ViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Fixed Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Synthesizer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Kernel Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Performer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Poly-SA&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Linear Transformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;SLAB&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Hydra ViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Hybrid Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EfficientFormerV2&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EfficientViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Next-ViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;CvT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;ResT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;CoaT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Fourier Attention&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;FNet&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;GFNet&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;AFNO&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Non-Attention Shuffling&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;MLP-Mixer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;FastViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EfficientMod&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;FocalNet&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;SwiftFormer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Token Sequence&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Token Removal&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;DynamicViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;A-ViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Token Merging&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;ToMe&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;Summary Tokens&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;CaiT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Token Learner&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;STViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;MLP Block&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;More MLPs&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Switch Transformer&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;HiViT&lt;/td&gt;
&lt;td style="text-align: left"&gt;
&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/details&gt;
&lt;h1 id="experimental-design"&gt;Experimental Design&lt;/h1&gt;
&lt;p&gt;We conduct a series of over 200 experiments on more than 45 models.&lt;/p&gt;
&lt;h2 id="training-pipeline"&gt;Training Pipeline&lt;/h2&gt;
&lt;p&gt;We compare models on even grounds by training from scratch with a standardized pipeline.
This pipeline is based on DeiT III
, an updated version of DeiT
.
To reduce bias, our pipeline is relatively simple and only consists of elements commonly used in CV.
In particular, we refrain from using knowledge distillation to prevent introducing bias from the choice of teacher model.
Any orthogonal techniques, like quantization, sample selection, and others, are not included as they can be applied to every model and would manifest as a systematic offset in the results.
To avoid issues from limited training data, we pre-train all models on ImageNet-21k
.&lt;/p&gt;
&lt;details&gt;
&lt;summary&gt;Training Hyperparameters&lt;/summary&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;&lt;/th&gt;
&lt;th style="text-align: center"&gt;Pretrain&lt;/th&gt;
&lt;th style="text-align: center"&gt;Finetune&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Dataset&lt;/td&gt;
&lt;td style="text-align: center"&gt;ImageNet-21k&lt;/td&gt;
&lt;td style="text-align: center"&gt;ImageNet-1k&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Epochs&lt;/td&gt;
&lt;td style="text-align: center"&gt;90&lt;/td&gt;
&lt;td style="text-align: center"&gt;50&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;LR&lt;/td&gt;
&lt;td style="text-align: center"&gt;$3 \times 10^{-3}$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$3 \times 10^{-4}&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Schedule&lt;/td&gt;
&lt;td style="text-align: center"&gt;cosine decay&lt;/td&gt;
&lt;td style="text-align: center"&gt;cosine decay&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Batch Size&lt;/td&gt;
&lt;td style="text-align: center"&gt;2048&lt;/td&gt;
&lt;td style="text-align: center"&gt;2048&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Warmup Schedule&lt;/td&gt;
&lt;td style="text-align: center"&gt;linear&lt;/td&gt;
&lt;td style="text-align: center"&gt;linear&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Warmup Epochs&lt;/td&gt;
&lt;td style="text-align: center"&gt;5&lt;/td&gt;
&lt;td style="text-align: center"&gt;5&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Weight Decay&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.02&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.02&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Gradient Clipping&lt;/td&gt;
&lt;td style="text-align: center"&gt;1.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;1.0&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Label Smoothing&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.1&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.1&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Drop Path Rate&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.05&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.05&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Optimizer&lt;/td&gt;
&lt;td style="text-align: center"&gt;Lamb&lt;/td&gt;
&lt;td style="text-align: center"&gt;Lamb&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Dropout Rate&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.0&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Mixed Precision&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Augmentation&lt;/td&gt;
&lt;td style="text-align: center"&gt;3-Augment&lt;/td&gt;
&lt;td style="text-align: center"&gt;3-Augment&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Image Resolution&lt;/td&gt;
&lt;td style="text-align: center"&gt;$224 \times 224$ or $192 \times 192$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$224 \times 224$ or $384 \times 384$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;GPUs&lt;/td&gt;
&lt;td style="text-align: center"&gt;4 NVIDIA A100&lt;/td&gt;
&lt;td style="text-align: center"&gt;4 or 8 NVIDIA A100&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/details&gt;
&lt;h2 id="efficiency-metrics"&gt;Efficiency Metrics&lt;/h2&gt;
&lt;p&gt;We track the following metrics for evaluating the model efficiency:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;Number of Parameters&lt;/li&gt;
&lt;li&gt;FLOPs&lt;/li&gt;
&lt;li&gt;Training time in GPU-hours at batch size 2048 for the full 50 epochs of finetuning on an A100 GPU&lt;/li&gt;
&lt;li&gt;Inference throughput in images per second at the optimal batch size on an A100 GPU&lt;/li&gt;
&lt;li&gt;Training memory over all GPUs during finetuning at batch size 2048&lt;/li&gt;
&lt;li&gt;Inference memory on a single GPU at batch size 1; the minimum amount of memory needed for inference&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For comparability, the empirical metrics are measured using the same setup.&lt;/p&gt;
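&lt;p&gt;As a rough sketch of how such empirical numbers can be obtained on a single GPU (simplified; the batch size, warm-up, and iteration counts here are illustrative, not our exact benchmarking protocol):&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;# Sketch: single-GPU measurement of inference throughput and peak memory.
import time
import torch

@torch.no_grad()
def measure_inference(model, batch_size=256, iters=50, resolution=224):
    model = model.cuda().eval()
    x = torch.randn(batch_size, 3, resolution, resolution, device="cuda")
    for _ in range(5):                                   # warm-up iterations
        model(x)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    images_per_second = batch_size * iters / (time.perf_counter() - start)
    peak_memory_gb = torch.cuda.max_memory_allocated() / 1e9
    return images_per_second, peak_memory_gb
&lt;/code&gt;&lt;/pre&gt;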
&lt;h1 id="results"&gt;Results&lt;/h1&gt;
&lt;h2 id="improved-training-pipeline"&gt;Improved Training Pipeline&lt;/h2&gt;
&lt;p&gt;To validate the fairness of our training pipeline, we compare our ImageNet-1k accuracy with the numbers reported in the original papers (whenever available).&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;Model&lt;/th&gt;
&lt;th style="text-align: center"&gt;Orig. DeiT&lt;/th&gt;
&lt;th style="text-align: center"&gt;Orig. Acc.&lt;/th&gt;
&lt;th style="text-align: center"&gt;Our Acc.&lt;/th&gt;
&lt;th style="text-align: center"&gt;&lt;/th&gt;
&lt;th style="text-align: left"&gt;Model&lt;/th&gt;
&lt;th style="text-align: center"&gt;Orig. DeiT&lt;/th&gt;
&lt;th style="text-align: center"&gt;Orig. Acc.&lt;/th&gt;
&lt;th style="text-align: center"&gt;Our Acc.&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;ViT-S (DeiT)&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;79.8&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;82.54&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;ViT-S (DeiT III)&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.6&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.54&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;XCiT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.65&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Swin-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;83.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;84.87&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Swin-V2-Ti&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;81.7&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.09&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;Wave-ViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.7&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.61&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Poly-SA-ViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;71.48&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;78.34&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;SLAB-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;80.0&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.70&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;EfficientFormer-V2-S0&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;75.7&lt;/strong&gt;${}^D$&lt;/td&gt;
&lt;td style="text-align: center"&gt;71.53&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;CvT-13&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.3&lt;/strong&gt;$\uparrow$&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.35&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;CoaT-Ti&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.37&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.42&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EfficientViT-B2&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;82.7&lt;/strong&gt;$\uparrow$&lt;/td&gt;
&lt;td style="text-align: center"&gt;81.53&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;NextViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.5&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.92&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;ResT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;79.6&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;79.92&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;FocalNet-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;83.4&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;84.91&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;SwiftFormer-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;78.5&lt;/strong&gt;${}^D$&lt;/td&gt;
&lt;td style="text-align: center"&gt;76.41&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;FastViT-S12&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;79.8&lt;/strong&gt;$\uparrow$&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.77&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EfficientMod-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;81.0&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;80.21&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;GFNet-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;80.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;81.33&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;79.4&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;82.29&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;DynamicViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;83.0&lt;/strong&gt;${}^D$&lt;/td&gt;
&lt;td style="text-align: center"&gt;81.09&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;EViT Fuse&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;79.5&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;81.96&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;ToMe-ViT-S&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;79.42&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;82.11&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;TokenLearner-ViT-8&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;77.87$\downarrow$&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;80.66&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;STViT-Swin-Ti&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;80.8&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;82.22&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: left"&gt;CaiT-S24&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;#x2705;&lt;/td&gt;
&lt;td style="text-align: center"&gt;82.7&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;84.91&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;We find that 13 out of 26 papers base their training pipelines on DeiT, making our pipeline a good fit.
Additionally, we see that with our pipeline accuracy increases by $0.85$% on average.
Most models reporting higher performance using the original pipeline were trained with knowledge distillation (which we avoid to reduce bias) or using a larger image resolution (which we show is inefficient).&lt;/p&gt;
&lt;h2 id="number-of-parameters"&gt;Number of Parameters&lt;/h2&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;Use widescreen format for the best view of the interactive plots.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;
&lt;div id="/plotly/model_vs_acc_per_param_and_acc.json" class="plotly" style="height:600px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/model_vs_acc_per_param_and_acc.json");
Plotly.d3.json("/plotly/model_vs_acc_per_param_and_acc.json", function(err, fig) {
Plotly.plot('\/plotly\/model_vs_acc_per_param_and_acc.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;&lt;/p&gt;
&lt;p&gt;We find that, in general, the accuracy per parameter goes down as models get larger.
This is especially true for the ViT models, which are more parameter-efficient than models of similar accuracy at smaller sizes (ViT-Ti) and less parameter-efficient at larger sizes (ViT-B).
The most parameter-efficient models are &lt;em&gt;Hybrid Attention&lt;/em&gt; models (EfficientFormerV2-S0, CoaT-Ti) and other &lt;em&gt;Non-attention shuffling&lt;/em&gt; models that incorporate convolutions (SwiftFormer, FastViT).&lt;/p&gt;
&lt;h2 id="speed"&gt;Speed&lt;/h2&gt;
&lt;p&gt;&lt;a name="throughput-plot"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h3 id="inference-throughput"&gt;Inference Throughput&lt;/h3&gt;
&lt;p&gt;
&lt;div id="/plotly/throughput_vs_accuracy.json" class="plotly" style="height:750px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/throughput_vs_accuracy.json");
Plotly.d3.json("/plotly/throughput_vs_accuracy.json", function(err, fig) {
Plotly.plot('\/plotly\/throughput_vs_accuracy.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;
&lt;br&gt;
The models we evaluate often claim a superior throughput vs. accuracy trade-off compared to ViT.
However, we find that ViT remains Pareto optimal at all model sizes.
Only a few models (Synthesizer-FR, NextViT, and some &lt;em&gt;Token Sequence&lt;/em&gt; models) improve the Pareto front compared to a ViT of comparable size.
We find that these observations replicate on other datasets and even when using CPUs instead of GPUs.&lt;/p&gt;
&lt;h3 id="training-speed"&gt;Training Speed&lt;/h3&gt;
&lt;p&gt;
&lt;div id="/plotly/ft_time_vs_acc.json" class="plotly" style="height:750px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/ft_time_vs_acc.json");
Plotly.d3.json("/plotly/ft_time_vs_acc.json", function(err, fig) {
Plotly.plot('\/plotly\/ft_time_vs_acc.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;
&lt;br&gt;
This Pareto front is very similar to the one for inference time.
Here, some &lt;em&gt;Token Sequence&lt;/em&gt; models are highly efficient; in particular TokenLearner.&lt;/p&gt;
&lt;p&gt;Generally, ViT is still a solid choice for speed.&lt;/p&gt;
&lt;h2 id="memory"&gt;Memory&lt;/h2&gt;
&lt;h3 id="training-memory"&gt;Training Memory&lt;/h3&gt;
&lt;p&gt;
&lt;div id="/plotly/train_mem_vs_acc.json" class="plotly" style="height:750px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/train_mem_vs_acc.json");
Plotly.d3.json("/plotly/train_mem_vs_acc.json", function(err, fig) {
Plotly.plot('\/plotly\/train_mem_vs_acc.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;
&lt;br&gt;
Training memory again exhibits a pattern similar to the ones above.
There is a stark contrast between models using low-resolution and high-resolution images: the high-resolution models need significantly more memory while gaining relatively little accuracy.&lt;/p&gt;
&lt;h3 id="inference-memory"&gt;Inference Memory&lt;/h3&gt;
&lt;p&gt;
&lt;div id="/plotly/inf_mem_vs_acc.json" class="plotly" style="height:750px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/inf_mem_vs_acc.json");
Plotly.d3.json("/plotly/inf_mem_vs_acc.json", function(err, fig) {
Plotly.plot('\/plotly\/inf_mem_vs_acc.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;
&lt;br&gt;
The Pareto front for inference memory differs the most from all the others.
It is the only one where ViT is &lt;em&gt;not&lt;/em&gt; Pareto optimal.
Instead, &lt;em&gt;Hybrid Attention&lt;/em&gt; and convolution-based models excel, similar to the
.
It is also the only setup where a model (EViT) using 384px-resolution images is Pareto optimal.&lt;/p&gt;
&lt;h2 id="scaling-behaviors"&gt;Scaling Behaviors&lt;/h2&gt;
&lt;p&gt;Our observations reveal that fine-tuning at a higher resolution is inefficient.
While it may improve accuracy, it entails a significant increase in computational cost and a substantial reduction in throughput.
Scaling up the model instead ends up being more efficient.
This can be seen when comparing the corresponding Pareto fronts for
,
, and
.&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;A few examples for scaling the model vs. scaling the image size:&lt;/strong&gt;
&lt;div id="/plotly/scaling_behavior.json" class="plotly" style="height:350px;max-width:1000px;margin: auto;"&gt;&lt;/div&gt;
&lt;script&gt;
console.log("plotting \/plotly\/scaling_behavior.json");
Plotly.d3.json("/plotly/scaling_behavior.json", function(err, fig) {
Plotly.plot('\/plotly\/scaling_behavior.json', fig.data, fig.layout, {responsive: true, displayModeBar: false});
});
&lt;/script&gt;
We see that scaling up the model size is always more efficient than scaling up the image resolution.&lt;/p&gt;
&lt;h2 id="correlation-of-metrics"&gt;Correlation of Metrics&lt;/h2&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;$\text{corr}(x, y)$&lt;/th&gt;
&lt;th style="text-align: center"&gt;Params&lt;/th&gt;
&lt;th style="text-align: center"&gt;Training&lt;br&gt;Time&lt;/th&gt;
&lt;th style="text-align: center"&gt;Training&lt;br&gt;Memory&lt;/th&gt;
&lt;th style="text-align: center"&gt;Inference&lt;br&gt;Time&lt;/th&gt;
&lt;th style="text-align: center"&gt;Inference&lt;br&gt;Memory&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;FLOPS&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.30&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.72&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;0.85&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.48&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.42&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Params&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.05&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.18&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.02&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.40&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Training Time&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;0.89&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;0.81&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.17&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Training Memory&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.71&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.48&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Inference Time&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;0.13&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;The highest correlation of 0.89 is between fine-tuning time and training memory.
This suggests a common underlying factor or bottleneck, possibly related to the necessity of memory reads during training.
We find that estimating computational costs based only on theoretical metrics is unreliable, as has been reported in [
,
].
Consequently, assessing model efficiency in practice requires the empirical measurement of throughput and memory requirements.&lt;/p&gt;
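&lt;p&gt;As a minimal sketch of what such an empirical measurement can look like (this is not the exact benchmarking code used for the plots above; the model, batch size, and image size are placeholders), one can time forward passes and read out the peak GPU memory with PyTorch:&lt;/p&gt;
```Python
import time
import torch

@torch.no_grad()
def measure_throughput_and_memory(model, batch_size=64, img_size=224, steps=50, device="cuda"):
    """Rough empirical estimate of inference throughput [images/s] and peak memory [MiB]."""
    model = model.to(device).eval()
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    for _ in range(10):                      # warm-up so lazy initialization does not distort timings
        model(x)
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    for _ in range(steps):
        model(x)
    torch.cuda.synchronize(device)
    elapsed = time.perf_counter() - start
    throughput = steps * batch_size / elapsed
    peak_mem_mib = torch.cuda.max_memory_allocated(device) / 2**20
    return throughput, peak_mem_mib
```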
&lt;h1 id="tldr-which-transformer-to-favor"&gt;TlDr: Which Transformer to Favor?&lt;/h1&gt;
&lt;p&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Flowchart to answer the question: Which Transformer to Favor"
srcset="https://nauen-it.de/publications/wtf-benchmark/images/wtf_flowchart_hu_dee4becf42a4ecd6.webp 320w, https://nauen-it.de/publications/wtf-benchmark/images/wtf_flowchart_hu_19da909329796f5f.webp 480w, https://nauen-it.de/publications/wtf-benchmark/images/wtf_flowchart_hu_7c33393704142ba4.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/wtf-benchmark/images/wtf_flowchart_hu_dee4becf42a4ecd6.webp"
width="760"
height="172"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/p&gt;
&lt;p&gt;Our benchmark offers actionable insights for answering the question of &lt;em&gt;which transformer to favor&lt;/em&gt; in the form of models and strategies to use.
We have compiled an overview of these in the flowchart above.
ViT remains the preferred choice overall.
However, Token Sequence methods can become viable alternatives when speed and training efficiency are of importance.
For scenarios with significant inference memory constraints, considering Hybrid CNN-attention models can prove advantageous.&lt;/p&gt;
&lt;p&gt;We additionally find that it is much more efficient to scale up the model size than to scale up the image resolution.
This goes against the trend of efficient models being evaluated using higher resolution images, which cancels out possible efficiency gains.&lt;/p&gt;
&lt;h1 id="references"&gt;References&lt;/h1&gt;
&lt;p&gt;For references and links to the efficient transformer models, see the
.&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;&lt;a name="ce-dl"&gt;Brian R Bartoldson, Bhavya Kailkhura, and Davis Blalock&lt;/a&gt;. &lt;em&gt;Compute-efficient deep learning: Algorithmic trends and opportunities&lt;/em&gt;. Journal of Machine Learning Research, 24(122):1–77, 2023.&lt;/li&gt;
&lt;li&gt;&lt;a name="eff-misnomer"&gt;Mostafa Dehghani, Yi Tay, Anurag Arnab, Lucas Beyer, and Ashish Vaswani&lt;/a&gt;. &lt;em&gt;The efficiency misnomer&lt;/em&gt;. In International Conference on Learning Representations, 2022.&lt;/li&gt;
&lt;li&gt;&lt;a name="ImageNet"&gt;Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei&lt;/a&gt;. &lt;em&gt;ImageNet: A large-scale hierarchical image database&lt;/em&gt;. In 2009 IEEE Conference on Computer Vision and Pattern Recognition. IEEE, 2009.&lt;/li&gt;
&lt;li&gt;&lt;a name="ViT"&gt;Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, and Neil Houlsby&lt;/a&gt;. &lt;em&gt;An image is worth 16x16 words: Transformers for image recognition at scale&lt;/em&gt;. In 9th International Conference on Learning Rep- resentations, ICLR 2021, Virtual Event, Austria, May 3-7, 2021. OpenReview.net, 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="DeiT"&gt;Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Herve Jegou&lt;/a&gt;. &lt;em&gt;Training data-efficient image transformers &amp;amp; distillation through attention&lt;/em&gt;. In Marina Meila and Tong Zhang, editors, Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 10347–10357. PMLR, 7 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="DeiT-III"&gt;Hugo Touvron, Matthieu Cord, and Hervé Jégou&lt;/a&gt;. &lt;em&gt;Deit iii: Revenge of the vit&lt;/em&gt;. In Shai Avidan, Gabriel Brostow, Moustapha Cissé, Giovanni Maria Farinella, and Tal Hassner, editors, Computer Vision – ECCV 2022, pages 516–533, Cham, 2022. Springer Nature Switzerland.&lt;/li&gt;
&lt;li&gt;&lt;a name="Transformer"&gt;Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin&lt;/a&gt;. &lt;em&gt;Attention is all you need&lt;/em&gt;. In I. Guyon, U. Von Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems, volume 30. Curran Associates, Inc., 2017.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
,
&lt;/p&gt;</description></item><item><title>A Study in Dataset Distillation for Image Super-Resolution</title><link>https://nauen-it.de/publications/dataset-distillation-sr/</link><pubDate>Wed, 05 Feb 2025 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/dataset-distillation-sr/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
&lt;/p&gt;</description></item><item><title>TaylorShift: Shifting the Complexity of Self-Attention from Squared to Linear (and Back) using Taylor-Softmax</title><link>https://nauen-it.de/publications/taylor-shift/</link><pubDate>Tue, 03 Dec 2024 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/taylor-shift/</guid><description>&lt;p&gt;&lt;a name="Fig1"&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Operations involved in TaylorShift"
srcset="https://nauen-it.de/publications/taylor-shift/images/operations_hu_e0764c064eb9493a.webp 320w, https://nauen-it.de/publications/taylor-shift/images/operations_hu_3f8ec5ea50ff7396.webp 480w, https://nauen-it.de/publications/taylor-shift/images/operations_hu_658eb6c14adbfdf8.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/taylor-shift/images/operations_hu_e0764c064eb9493a.webp"
width="760"
height="250"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/a&gt;&lt;/p&gt;
&lt;h1 id="introduction"&gt;Introduction&lt;/h1&gt;
&lt;p&gt;Despite their remarkable success, Transformers face a significant challenge when dealing with long sequences due to the quadratic complexity of the attention mechanism.
This limitation hinders their application to tasks involving extensive contextual information, such as processing long documents or high-resolution images.
While various approaches have been proposed to address this issue, they often sacrifice accuracy, specialize in specific domains, or neglect individual token-to-token interactions.
To overcome these limitations, we introduce &lt;em&gt;TaylorShift&lt;/em&gt;, a novel method that reformulates the softmax function in the attention mechanism using the Taylor approximation of the exponential.
By combining this approximation with a tensor-product-based operator, &lt;em&gt;TaylorShift&lt;/em&gt; achieves linear-time complexity while preserving the essential token-to-token interactions.
We analyze the efficiency of &lt;em&gt;TaylorShift&lt;/em&gt; in depth, both analytically and empirically, and find that it outperforms the standard transformer architecture in 4 out of 5 tasks.&lt;/p&gt;
&lt;h1 id="how-does-taylorshift-work"&gt;How does &lt;em&gt;TaylorShift&lt;/em&gt; work?&lt;/h1&gt;
&lt;p&gt;Essentially, &lt;em&gt;TaylorShift&lt;/em&gt; works by replacing the exponential function in the softmax by its
Taylor approximation. For a vector $\mathbf x = [x\_1, ..., x\_m] = [x_i]_{i = 1}^m$:
&lt;/p&gt;
$$
\text{softmax}(x) = \left[\frac{\exp(x_i)}{\sum\_{j} \exp(x_j)}\right]\_{i = 1}^m \approx \left[ \frac{\frac{x_i^2}{2} + x_i + 1}{\sum_j \frac{x_j^2}{2} + x_j + 1} \right]\_{i = 1}^m = \text{T-SM}(x)
$$&lt;h2 id="direct-taylorshift"&gt;Direct TaylorShift&lt;/h2&gt;
&lt;p&gt;We call the direct implementation of the attention mechanism
using the Taylor Softmax &lt;em&gt;direct-TaylorShift&lt;/em&gt;, as seen above.
For queries $Q$, keys $K$, and values $V$, this becomes:
&lt;/p&gt;
$$
Y = \text{T-SM}(Q K^\top) V
$$&lt;h2 id="efficient-taylorshift"&gt;Efficient TaylorShift&lt;/h2&gt;
&lt;p&gt;&lt;em&gt;Direct-TaylorShift&lt;/em&gt; has the same scaling behavior as standard attention.
However, we can reduce its computational complexity from $\mathcal O(N^2 d)$ to $\mathcal O(N d^3)$ by reordering the operations internally.
This becomes useful for long sequences, where $N \gg d$.&lt;/p&gt;
&lt;p&gt;Let us first introduce a tensor-product-based operator:
&lt;/p&gt;
$$
\boxtimes: \mathbb R^{N \times d} \times \mathbb R^{N \times d} \to \mathbb R^{N \times d^2}.
$$&lt;p&gt;
Basically, we take two lists of $d$-dimensional vectors $[a\_i \in \mathbb R^d]\_i$ and $[b\_i \in \mathbb R^d]\_i$, and for each index $i$ we multiply each element of $a_i$ with every element of $b_i$.
The result is $d^2$ dimensional, since that is the number of possible combinations.
We also write $A^{\boxtimes 2} := A \boxtimes A$.&lt;/p&gt;
&lt;details&gt;
&lt;summary&gt;Mathematical Details&lt;/summary&gt;
In mathematical terms, we define
$$
[A \boxtimes B]_n = \iota(A_n \otimes B_n) \in \mathbb R^{d^2} \hspace{10pt} \forall n=1,
..., N
$$
Here, $A_n$, $B_n$, and $[A \boxtimes B]_n$ is the $n$-th entry of the respective matrix. $\otimes$ is the tensor product (or &lt;a href="https://en.wikipedia.org/wiki/Outer_product" target="_blank"&gt;outer product&lt;/a&gt;) of two $d$-dimensional vectors and $\iota: \mathbb R^{d \times d} \to \mathbb R^{d^2}$ is the &lt;i&gt;canonical isomorphism&lt;/i&gt; (basically, it just reorders the entries of a matrix into a vector; the exact order does not matter, as long as it's always the same one).
&lt;/details&gt;
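&lt;p&gt;As a minimal sketch (assuming PyTorch; the helper name box is ours), the operator $\boxtimes$ can be implemented with a broadcasted outer product followed by flattening the last two dimensions:&lt;/p&gt;
```Python
import torch

def box(a, b):
    """Row-wise tensor product: maps (..., N, d) x (..., N, d) to (..., N, d*d)."""
    # For every row n, form the outer product of a_n and b_n and flatten it to a d^2-vector.
    return (a.unsqueeze(-1) * b.unsqueeze(-2)).flatten(start_dim=-2)

N, d = 6, 4
A, B = torch.randn(N, d), torch.randn(N, d)
out = box(A, B)
assert out.shape == (N, d * d)
# Entry check: [A box B]_{n, pi(k, l)} = A_{n, k} * B_{n, l}, with pi(k, l) = k*d + l here.
n, k, l = 2, 1, 3
assert torch.isclose(out[n, k * d + l], A[n, k] * B[n, l])
```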
&lt;p&gt;It turns out that, by using this operator, we can calculate &lt;em&gt;TaylorShift&lt;/em&gt; more efficiently:
&lt;/p&gt;
$$
Y = Y_\text{nom} \oslash Y_\text{denom} = \left[ \frac{[Y_\text{nom}]\_{i, :}}{[Y_\text{denom}]\_i} \right]\_{i = 1}^N
$$&lt;p&gt;
with
&lt;/p&gt;
$$
Y_\text{nom} = \frac 1 2 Q^{\boxtimes 2} \left( (K^{\boxtimes 2})^\top V \right) + Q (K^\top V) + \sum_\text{columns} V.
$$&lt;p&gt;
$Y_\text{denom}$ is the same, but with $\mathbb 1 = [1, ..., 1]$ instead of $V$.&lt;/p&gt;
&lt;details&gt;
&lt;summary&gt;Mathematical Details&lt;/summary&gt;
We have
$$
Y_\text{nom} = \frac 1 2 (Q K^\top)^{\odot 2} V + Q K^\top V + \sum_\text{columns} V.
$$
Let $ \pi: \{1, .., d\} \times \{1, ..., d\} \to \{1, ..., d^2\} $ be the map that describes the reordering that $\iota$ (defined in the &lt;i&gt;Mathematical Details&lt;/i&gt; section above) does.
Then we have
$$
\left[ A^{\boxtimes 2} \right]_{n, \pi(k, \ell)} = (A_n \otimes A_n)_{k, \ell} = A_{n, k} A_{n, \ell}.
$$
This allows us to linearize the squared term $(Q K^\top)^{\odot 2} V$ by using $\boxtimes$ to unroll the square of a sum along a sum of $d^2$ elements:
$$
\begin{align*}
\left[(QK^\top)^{\odot 2} \right]_{i, j} =&amp; \left( \sum_{k = 1}^d Q_{ik} K_{jk} \right)^2 \\
=&amp; \sum_{k, \ell = 1}^d Q_{ik} Q_{i\ell} K_{jk} K_{j \ell} \\
=&amp; \sum_{k, \ell = 1}^d \left[ Q^{\boxtimes 2} \right]_{i, \pi(k, \ell)} \left[ K^{\boxtimes 2} \right]_{j, \pi(k, \ell)} \\
=&amp; \left[ Q^{\boxtimes 2} \right]_i \left[ K^{\boxtimes 2} \right]_j^\top
\end{align*}
$$
Therefore
$$
(QK^\top)^{\odot 2} V = Q^{\boxtimes 2} (K^{\boxtimes 2})^\top V,
$$
which can be computed in $\mathcal O(N d^3)$ by multiplying from right to left.
We can also calculate $Y_\text{nom}$ and $Y_\text{denom}$ at once by appending a column of ones to $V$, i.e. setting $V \gets [\mathbb 1 \,|\, V]$.
&lt;/details&gt;
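&lt;p&gt;To make the reordering concrete, here is a small numerical sanity check (a sketch in PyTorch for a single head, without the normalization introduced below): the directly computed Taylor-Softmax attention and the reordered, efficient computation agree up to floating-point error.&lt;/p&gt;
```Python
import torch

def box(a, b):
    # Row-wise tensor product (see the sketch above): (N, d) x (N, d) to (N, d^2).
    return (a.unsqueeze(-1) * b.unsqueeze(-2)).flatten(start_dim=-2)

torch.manual_seed(0)
N, d = 128, 8
Q, K, V = (torch.randn(N, d, dtype=torch.double) for _ in range(3))

# Direct: apply the Taylor softmax T-SM row-wise to QK^T, then multiply by V.
A = Q @ K.T
num = 0.5 * A.pow(2) + A + 1.0
Y_direct = (num / num.sum(dim=-1, keepdim=True)) @ V

# Efficient: reorder the operations so that no N x N matrix is ever formed.
def numerator(X):
    return 0.5 * box(Q, Q) @ (box(K, K).T @ X) + Q @ (K.T @ X) + X.sum(dim=0, keepdim=True)

Y_efficient = numerator(V) / numerator(torch.ones(N, 1, dtype=torch.double))

assert torch.allclose(Y_direct, Y_efficient)
```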
&lt;h2 id="normalization"&gt;Normalization&lt;/h2&gt;
&lt;p&gt;We found that some intermediate results of &lt;em&gt;TaylorShift&lt;/em&gt; tended to have very large norms, which ultimately led to training failures.
We introduce the following three steps for normalization (a short code sketch follows the list):&lt;/p&gt;
&lt;ol&gt;
&lt;li&gt;Normalize the queries and keys to unit norm and introduce an additional attention temperature parameter (per attention head) $\tau \in \mathbb R$:
$$
q_i \gets \frac{\tau q_i}{||q_i||_2}, \hspace{10pt} k_i \gets \frac{k_i}{||k_i||_2} \hspace{10pt} \forall i=1, ..., N
$$&lt;/li&gt;
&lt;li&gt;Counteract the scaling behaviors by multiplying $Q$ and $K$ by $\sqrt[4]{d}$ and $V$ by $\frac 1 N$.&lt;/li&gt;
&lt;li&gt;Normalize the output by multiplying by $\sqrt{\frac N d}$.&lt;/li&gt;
&lt;/ol&gt;
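&lt;p&gt;A sketch of these three steps for a single attention head (assuming PyTorch; in the model, $\tau$ is a learnable per-head parameter):&lt;/p&gt;
```Python
import torch
import torch.nn.functional as F

def normalize_inputs(q, k, v, tau=1.0):
    """Steps 1 and 2, applied to (N, d)-shaped queries, keys, and values of one head."""
    N, d = q.shape
    q = tau * F.normalize(q, dim=-1) * d**0.25   # unit norm, temperature, counteract scaling
    k = F.normalize(k, dim=-1) * d**0.25
    v = v / N
    return q, k, v

def normalize_output(y):
    """Step 3: rescale the attention output."""
    N, d = y.shape
    return y * (N / d) ** 0.5
```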
&lt;details&gt;
&lt;summary&gt;Scaling Behavior Details&lt;/summary&gt;
Experimentally, we find the following approximate mean sizes for intermediate results with $Q, K,$ and $V$ sampled uniformly from the unit sphere:
&lt;div class="scroll-container"&gt;&lt;table class="table"&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;Interm. Expr.&lt;/th&gt;
&lt;th style="text-align: center"&gt;$(K^{\boxtimes 2})^\top V$&lt;/th&gt;
&lt;th style="text-align: center"&gt;$(QK^\top)^{\odot 2} V$&lt;/th&gt;
&lt;th style="text-align: center"&gt;$ QK^\top V$&lt;/th&gt;
&lt;th style="text-align: center"&gt;$Y_\text{denom}$&lt;/th&gt;
&lt;th style="text-align: center"&gt;$Y$&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Size ($\approx$)&lt;/td&gt;
&lt;td style="text-align: center"&gt;$\frac{N}{\sqrt d}$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$\frac N d$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$\sqrt N (1 + \frac{1}{4d})$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$N (2 + \frac{1}{d})$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$\sqrt{\frac d N}$&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Size after Normalization ($\approx$)&lt;/td&gt;
&lt;td style="text-align: center"&gt;$1$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$1$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$\frac{1}{\sqrt{Nd}} (1 + \frac{1}{4d})$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$2 + \frac{1}{d}$&lt;/td&gt;
&lt;td style="text-align: center"&gt;$1$&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
&lt;/details&gt;
&lt;h4 id="efficient-taylorshift-algorithm"&gt;&lt;em&gt;Efficient-TaylorShift&lt;/em&gt; Algorithm&lt;/h4&gt;
&lt;p&gt;We compile all the information into the pseudocode for &lt;em&gt;efficient-TaylorShift&lt;/em&gt;:
&lt;a name="Algorithm"&gt;
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Normalized efficient-TaylorShift algorithm"
srcset="https://nauen-it.de/publications/taylor-shift/images/algorithm_hu_1bdb06ab9fdb7082.webp 320w, https://nauen-it.de/publications/taylor-shift/images/algorithm_hu_d32c253a9845a9dc.webp 480w, https://nauen-it.de/publications/taylor-shift/images/algorithm_hu_c00c8dd319ba7479.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/taylor-shift/images/algorithm_hu_1bdb06ab9fdb7082.webp"
width="760"
height="442"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
&lt;/a&gt;&lt;/p&gt;
&lt;p&gt;Find the PyTorch implementation
.&lt;/p&gt;
&lt;!-- &lt;details&gt;
&lt;summary&gt;Python Code&lt;/summary&gt;
```Python
import math
import torch
def box_tensor(a: torch.Tensor, b: torch.Tensor) -&gt; torch.Tensor:
"""Calculate a ⊠ b.
Args:
a (torch.Tensor): Tensor of shape (..., N, d)
b (torch.Tensor): Tensor of shape (..., N, d)
Returns:
torch.Tensor: Tensor of shape (..., N, d^2)
"""
return (a.unsqueeze(-1) * b.unsqueeze(-2)).view(list(a.shape)[:-1] + [-1]) # uses the broadcasted Hadamard product
def efficient_taylor_shift_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, temperature=1.) -&gt; torch.Tensor:
"""Calculate efficient-TaylorShift attention.
Tensor shapes:
B: batch size
H: number of attention heads
N: sequence length
d: per-head dimension
Args:
q (torch.Tensor): Tensor of shape (B, H, N, d)
k (torch.Tensor): Tensor of shape (B, H, N, d)
v (torch.Tensor): Tensor of shape (B, H, N, d)
temperature (optional): Attention temperature. Can be learnable.
Returns:
torch.Tensor: Result of shape (B, H, N, d)
"""
B, H, N, d = q.shape
# The extra dimension is for calculating the nominator and denominator at once.
# The factor sqrt(d / N) in the denominator is for output normalization.
v = torch.cat([math.sqrt(d / N) * torch.ones(*v.shape[:-1], 1, device=v.device, dtype=v.dtype), v], dim=-1)
# normalize Q and K
scale = d**.25
q = torch.nn.functional.normalize(q, dim=-1) * (scale * temperature)
k = torch.nn.functional.normalize(k, dim=-1) * scale
# normalize V
v = v / N
kv_mod = box_tensor(k, k).transpose(-1, -2) @ v
y = (
.5 * box_tensor(q, q) @ kv_mod
+ scale**2 * q @ (k.transpose(-1, -2) @ v)
+ scale**4 * v.sum(dim=-2).unsqueeze(-2)
)
y_norm, y = y[:, :, :, :1], y[:, :, :, 1:] # split y_denom=y_norm and y_nom=y
return y / y_norm # normalization
def direct_taylor_shift_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, temperature=1.) -&gt; torch.Tensor:
"""Calculate direct-TaylorShift attention.
Tensor shapes:
B: batch size
H: number of attention heads
N: sequence length
d: per-head dimension
Args:
q (torch.Tensor): Tensor of shape (B, H, N, d)
k (torch.Tensor): Tensor of shape (B, H, N, d)
v (torch.Tensor): Tensor of shape (B, H, N, d)
temperature (optional): Attention temperature. Can be learnable.
Returns:
torch.Tensor: Result of shape (B, H, N, d)
"""
# Normalize Q and K
q = torch.nn.functional.normalize(q, dim=-1) * temperature
k = torch.nn.functional.normalize(k, dim=-1)
attn = q @ k.transpose(-2, -1)
# Normalize possible large intermediate values
attn_max = attn.abs().max(dim=-1).values.view(*attn.shape[:-1], 1)
max_val = .5 * attn_max.square() + attn_max + 1
attn = .5 * (attn / max_val.sqrt()).square() + attn / max_val + 1 / max_val
attn = attn / attn.sum(-1).view(*attn.shape[:-1], 1)
# Output normalization
B, H, N, d = q.shape
v = v * math.sqrt(N / d)
return attn @ v
```
&lt;/details&gt; --&gt;
&lt;h1 id="how-efficient-is-efficient-taylorshift"&gt;How efficient is &lt;em&gt;efficient-TaylorShift&lt;/em&gt;?&lt;/h1&gt;
&lt;p&gt;We analyze the circumstances where &lt;em&gt;efficient-TaylorShift&lt;/em&gt; is more efficient than &lt;em&gt;direct-TaylorShift&lt;/em&gt; in terms of speed, based on the number of floating point operations, and memory, based on the size of intermediate results.&lt;/p&gt;
&lt;h2 id="floating-point-operations"&gt;Floating Point Operations&lt;/h2&gt;
&lt;p&gt;The number of floating point operations for &lt;em&gt;direct-TaylorShift&lt;/em&gt; and &lt;em&gt;efficient-TaylorShift&lt;/em&gt; is
&lt;/p&gt;
$$\text{ops}_\text{dir} = 4N^2 d + 6 N^2,$$&lt;p&gt;
&lt;/p&gt;
$$\text{ops}\_\text{eff} = N (4d^3 + 10d^2 + 9d + 4).$$&lt;p&gt;Therefore, there exists an $N_0 = N_0(d)$, such that &lt;em&gt;efficient-TaylorShift&lt;/em&gt; is more efficient for all $N &gt; N_0$.
We calculate
&lt;/p&gt;
$$
N_0 = \frac{4d^3 + 10d^2 + 9d + 4}{4d + 6} \leq d^2 + d + \frac 3 4.
$$&lt;details&gt;
&lt;summary&gt;Mathematical Details&lt;/summary&gt;
We need the following operations:
&lt;p&gt;&lt;em&gt;direct-TaylorShift&lt;/em&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$2N^2 d$ for the multiplication of $QK^\top$,&lt;/li&gt;
&lt;li&gt;$4N^2$ operations to apply $x \mapsto \frac 1 2 x^2 + x + 1$ element-wise to that matrix,&lt;/li&gt;
&lt;li&gt;$2N^2$ operations for normalization,&lt;/li&gt;
&lt;li&gt;$2N^2 d$ operations for the final multiplication by $V$
$$
\Rightarrow \text{ops}_\text{dir} = 4 N^2 d + 6 N^2
$$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;&lt;em&gt;efficient-TaylorShift&lt;/em&gt;:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$2N d^2$ operations for $K^{\boxtimes 2}$ and $Q^{\boxtimes 2}$,&lt;/li&gt;
&lt;li&gt;$2 N d^2 (d + 1)$ operations to multiply by $V \in \mathbb R^{N \times (d+1)}$ and get $(K^{\boxtimes 2})^\top V$,&lt;/li&gt;
&lt;li&gt;$2 N d^2 (d + 1)$ operations for the final multiplication by $Q^{\boxtimes 2}$,&lt;/li&gt;
&lt;li&gt;$4 N d (d + 1)$ operations for computing $Q K^\top V$ from right to left,&lt;/li&gt;
&lt;li&gt;$N (d + 1)$ operations for summing up the columns of $V$,&lt;/li&gt;
&lt;li&gt;$3 N (d + 1)$ operations for the sums and scalar multiplication, and finally&lt;/li&gt;
&lt;li&gt;$N d$ operations for normalization.
$$
\Rightarrow \text{ops}_\text{eff} = N (2 d^2 + 4 d^2 (d + 1) + 4 d (d + 1) + 4 (d + 1) + d)
$$&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;We derive $N_0$ by setting $\text{ops}\_\text{dir} \stackrel{!}{=} \text{ops}\_\text{eff}$, which is equivalent to
&lt;/p&gt;
$$
N_0 = \frac{4d^3 + 10d^2 + 9d + 4}{4d + 6} \leq \frac{4d^3 + 6d^2}{4d + 6} + \frac{4d^2 + 6d}{4d + 6} + \frac{3d + 4.5}{4d + 6} = d^2 + d + \frac 3 4
$$&lt;/details&gt;
&lt;h2 id="size-of-intermediate-results"&gt;Size of intermediate Results&lt;/h2&gt;
&lt;p&gt;The size of the largest intermediate results needed at once for &lt;em&gt;direct-&lt;/em&gt; and &lt;em&gt;efficient-TaylorShift&lt;/em&gt; is
&lt;/p&gt;
$$\text{entries}_\text{dir} = \underbrace{dN}\_{\text{for } V} + \underbrace{2N^2}\_{\text{for } QK^\top \text{ and result}},$$&lt;p&gt;
&lt;/p&gt;
$$\text{entries}\_\text{eff} = d^2(d+1) + 2dN + (d+1)N + d^2N.$$&lt;p&gt;We can again find $N_1 = N_1(d)$, such that &lt;em&gt;efficient-TaylorShift&lt;/em&gt; is more memory efficient for all $N &gt; N_1$.
We find
&lt;/p&gt;
$$
N_1 = \frac 1 4 \left[ d^2 + 2 d + 1 + \sqrt{d^4 + 12 d^3 + 14 d^2 + 4d + 1} \right] \leq \frac 1 2 d^2 + 2 d + \frac 1 2.
$$&lt;details&gt;
&lt;summary&gt;Mathematical Details&lt;/summary&gt;
We count the number of entries in the largest intermediate results needed at once.
&lt;p&gt;For &lt;em&gt;direct-TaylorShift&lt;/em&gt; we need the largest intermediate memory when calculating $\text{T-SM}(QK^\top)$ from $QK^\top$.&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$d N$ entries of $V$&lt;/li&gt;
&lt;li&gt;$N^2$ entries of $QK^\top$&lt;/li&gt;
&lt;li&gt;$N^2$ entries for the result.
Note that we can&amp;rsquo;t simply reuse the memory from $QK^\top$, because we need to save at least one intermediate result when calculating $\frac 1 2 x^2 + x$.&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;For &lt;em&gt;efficient-TaylorShift&lt;/em&gt; we need the most memory when calculating $(K^{\boxtimes 2})^\top V$:&lt;/p&gt;
&lt;ul&gt;
&lt;li&gt;$2 N d$ entries for $Q,$ and $K$ for later&lt;/li&gt;
&lt;li&gt;$N (d + 1)$ entries for $V$ (also needed again later)&lt;/li&gt;
&lt;li&gt;$N d^2$ entries of $K^{\boxtimes 2}$&lt;/li&gt;
&lt;li&gt;$d^2 (d + 1)$ entries for the result&lt;/li&gt;
&lt;/ul&gt;
&lt;p&gt;We again derive $N_1$ by setting $\text{entries}\_\text{dir} \stackrel{!}{=} \text{entries}\_\text{eff}$ for $N_1$.
Therefore
&lt;/p&gt;
$$
N_1^2 - \frac{d^2 + 2d + 1}{2} N_1 - \frac{d^3 + d^2}{2} = 0
$$&lt;p&gt;
The larger of the two solutions is
&lt;/p&gt;
$$
\begin{align*}
N_1 =&amp; \frac 1 4 \left[ d^2 + 2d + 1 + \sqrt{(d^2 + 2d + 1)^2 + 8(d^3 + d^2)} \right] \\\\
=&amp; \frac 1 4 \left[ d^2 + 2d + 1 + \sqrt{d^4 + 12 d^3 + 14 d^2 + 4d + 1} \right].
\end{align*}
$$&lt;p&gt;
Since
&lt;/p&gt;
$$
(d^2 + 6d + 1)^2 = d^4 + 12d^3 + 38 d^2 + 12 d + 1 \geq d^4 + 12 d^3 + 14 d^2 + 4d + 1
$$&lt;p&gt;
we have
&lt;/p&gt;
$$
N_1 \leq \frac 1 2 d^2 + 2 d + \frac 1 2.
$$&lt;/details&gt;
&lt;details&gt;
&lt;summary&gt;$N_0$ and $N_1$ for typical values of $d$&lt;/summary&gt;
&lt;p&gt;&lt;strong&gt;Table:&lt;/strong&gt;
&lt;div class="scroll-container"&gt;&lt;table class="table"&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;d&lt;/th&gt;
&lt;th style="text-align: right"&gt;8&lt;/th&gt;
&lt;th style="text-align: right"&gt;16&lt;/th&gt;
&lt;th style="text-align: right"&gt;32&lt;/th&gt;
&lt;th style="text-align: right"&gt;64&lt;/th&gt;
&lt;th style="text-align: right"&gt;128&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;$N_0$&lt;/td&gt;
&lt;td style="text-align: right"&gt;73&lt;/td&gt;
&lt;td style="text-align: right"&gt;273&lt;/td&gt;
&lt;td style="text-align: right"&gt;1057&lt;/td&gt;
&lt;td style="text-align: right"&gt;4161&lt;/td&gt;
&lt;td style="text-align: right"&gt;16513&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;$N_1$&lt;/td&gt;
&lt;td style="text-align: right"&gt;47&lt;/td&gt;
&lt;td style="text-align: right"&gt;159&lt;/td&gt;
&lt;td style="text-align: right"&gt;574&lt;/td&gt;
&lt;td style="text-align: right"&gt;2174&lt;/td&gt;
&lt;td style="text-align: right"&gt;8446&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
&lt;/p&gt;
&lt;p&gt;&lt;strong&gt;Calculator:&lt;/strong&gt;&lt;/p&gt;
&lt;html&gt;
&lt;div style="display: flex; background: AliceBlue; width: fit-content; padding: 10px; border-radius: 20px; height: 40pt; margin-left: auto; margin-right: auto;"&gt;
&lt;p style="color: black; font-size: 20pt; font-family: MJXZERO, MJXTEX; padding-right: 5px;"&gt;&lt;i&gt;d&lt;/i&gt; =&lt;/p&gt;&lt;input id="d" type="number" onChange="MyFunction();" onkeyup="MyFunction();" min=1 size="4" value="32" style="background-color: #fbfcfc; color: black;"&gt;
&lt;p style="padding-left: 20px; color: black; font-size: 20pt; font-family: MJXZERO, MJXTEX;" id="f"&gt;=&gt; N_0 = 1057 N_1 = 577&lt;/p&gt;
&lt;/div&gt;
&lt;script&gt;
function MyFunction(){
var x = Number(document.getElementById("d").value);
var n0 = Math.ceil((4 * x**3 + 10 * x**2 + 9 * x + 4) / (4 * x + 6))
var n1 = Math.ceil(.25 * (x**2 + 2 * x + 1 + Math.sqrt(x**4 + 12 * x**3 + 14 * x**2 + 4 * x + 1)))
document.getElementById("f").innerHTML = "&amp;rArr; &lt;i&gt;N&lt;sub&gt;0&lt;/sub&gt;&lt;/i&gt; = &lt;i&gt;" + n0 + "&amp;nbsp;&amp;nbsp;&amp;nbsp; N&lt;sub&gt;1&lt;/sub&gt; &lt;/i&gt;=&lt;i&gt; " + n1 + "&lt;/i&gt;"
}
MyFunction();
&lt;/script&gt;
&lt;/html&gt;
&lt;/details&gt;
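&lt;p&gt;For reference, the same cutoffs can be computed offline with a direct transcription of the formulas above:&lt;/p&gt;
```Python
import math

def n0(d):
    """Sequence length above which efficient-TaylorShift needs fewer floating point operations."""
    return math.ceil((4 * d**3 + 10 * d**2 + 9 * d + 4) / (4 * d + 6))

def n1(d):
    """Sequence length above which efficient-TaylorShift needs less intermediate memory."""
    return math.ceil(0.25 * (d**2 + 2 * d + 1 + math.sqrt(d**4 + 12 * d**3 + 14 * d**2 + 4 * d + 1)))

print([(d, n0(d), n1(d)) for d in (8, 16, 32, 64, 128)])
# [(8, 73, 47), (16, 273, 159), (32, 1057, 574), (64, 4161, 2174), (128, 16513, 8446)]
```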
&lt;h1 id="how-can-we-increase-the-efficiency"&gt;How can we increase the efficiency?&lt;/h1&gt;
&lt;p&gt;In an effort to increase the efficiency while processing the same number of tokens $N$, one might opt to reduce the embedding dimension $d_\text{emb}$.
However, this might come at the cost of expressiveness.
Given that &lt;em&gt;efficient-TaylorShift&lt;/em&gt; scales with $\mathcal O(Nd^3)$, we can instead increase the number of attention heads $h$ with dimension $d = \frac{d_\text{emb}}{h}$ each.
We find that the number of operations is
&lt;/p&gt;
$$
\text{ops}\_\text{eff}(\text{MHSA}) = N \left( 4 \frac{d\_\text{emb}^3}{h^2} + 10 \frac{d\_\text{emb}^2}{h} + 9 d\_\text{emb} + 4h \right)
$$&lt;p&gt;
and the number of entries is
&lt;/p&gt;
$$
\text{entries}\_\text{eff}(\text{MHSA}) = \frac{d\_\text{emb}^3}{h^2} + (N + 1) \frac{d\_\text{emb}^2}{h} + 3N d\_\text{emb} + N h,
$$&lt;p&gt;
which are both strictly decreasing in $h$ over the feasible range $h \leq d\_\text{emb}$.
Therefore, &lt;em&gt;efficient-TaylorShift&lt;/em&gt; becomes faster and needs less memory with more attention heads.&lt;/p&gt;
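&lt;p&gt;A quick numerical check of this claim (the values $d\_\text{emb} = 256$ and $N = 1024$ are arbitrary examples):&lt;/p&gt;
```Python
def mhsa_ops(d_emb, h, N):
    """Floating point operations of efficient-TaylorShift multi-head attention (formula above)."""
    return N * (4 * d_emb**3 / h**2 + 10 * d_emb**2 / h + 9 * d_emb + 4 * h)

def mhsa_entries(d_emb, h, N):
    """Size of the intermediate results of efficient-TaylorShift multi-head attention."""
    return d_emb**3 / h**2 + (N + 1) * d_emb**2 / h + 3 * N * d_emb + N * h

# Both quantities shrink as the number of heads grows:
for h in (1, 2, 4, 8, 16, 32):
    print(h, mhsa_ops(256, h, 1024), mhsa_entries(256, h, 1024))
```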
&lt;details&gt;
&lt;summary&gt;Mathematical Details&lt;/summary&gt;
We identify the extreme points of both (as functions of $h$) by setting their derivatives to zero:
$$
\frac{\partial}{\partial h} \text{ops}_\text{eff}(\text{MHSA}) = -8 \frac{d_\text{emb}^3}{h^3} - 10 \frac{d_\text{emb}^2}{h^2} + 4
$$
By setting $d = \frac{d_\text{emb}}{h}$, we find that the above is zero at $d \approx 0.52$.
This would imply $h = \frac{1}{0.52} d_\text{emb}$, but the maximum value for $h$ is $d_\text{emb}$, since the number of dimensions $d$ has to be an integer.
&lt;p&gt;Similarly, for the number of entries, we find:
&lt;/p&gt;
$$
\frac{\partial}{\partial h} \text{entries}\_\text{eff}(\text{MHSA}) = -2 d^3 - (N + 1) d^2 + N \stackrel{!}{=} 0
$$&lt;p&gt;
&lt;/p&gt;
$$
\Leftrightarrow N = (2d + N + 1) d^2 \stackrel{d &gt; 0}{\geq} (N + 1) d^2
$$&lt;p&gt;
Therefore $1 &gt; \frac{N}{N+1} \geq d^2$ which would imply $1 &gt; d$ and therefore $h &gt; d_\text{emb}$ again, but the maximum value possible is $h = d_\text{emb}$.&lt;/p&gt;
&lt;/details&gt;
&lt;h1 id="empirical-evaluation"&gt;Empirical Evaluation&lt;/h1&gt;
&lt;h2 id="efficiency-of-taylorshift-attention"&gt;Efficiency of TaylorShift Attention&lt;/h2&gt;
&lt;p&gt;We first validate our theoretical analysis on the efficiency of &lt;em&gt;TaylorShift&lt;/em&gt; by measuring its inference time and memory usage under different configurations of $d$ and $N$:
&lt;figure &gt;
&lt;div class="flex justify-center "&gt;
&lt;div class="w-full" &gt;
&lt;img alt="Empirical cutoff points N_0 and N_1"
srcset="https://nauen-it.de/publications/taylor-shift/images/empirical_cutoffs_hu_d607fab7565c3998.webp 320w, https://nauen-it.de/publications/taylor-shift/images/empirical_cutoffs_hu_97842ffc055fe8e8.webp 480w, https://nauen-it.de/publications/taylor-shift/images/empirical_cutoffs_hu_598fb8cfb4bee0f6.webp 760w"
sizes="(max-width: 480px) 100vw, (max-width: 768px) 90vw, (max-width: 1024px) 80vw, 760px"
src="https://nauen-it.de/publications/taylor-shift/images/empirical_cutoffs_hu_d607fab7565c3998.webp"
width="760"
height="415"
loading="lazy" data-zoomable /&gt;&lt;/div&gt;
&lt;/div&gt;&lt;/figure&gt;
We observe that the empirical estimate for the memory transition point $\hat N_1$ coincides almost exactly with the theoretical value $N_1$, with an error of at most $0.6\\%$.
The difference between the empirical speed transition point $\hat N_0$ and the theoretical one $N_0$ is approximately proportional to $d$: $\hat N_0 - N_0 \approx 18 d$.
We hypothesize that the more sequential nature of &lt;em&gt;efficient-TaylorShift&lt;/em&gt; results in additional, costly reads and writes in GPU memory.
This might indicate that a low-level, IO-efficient implementation could yield further efficiency gains for &lt;em&gt;efficient-TaylorShift&lt;/em&gt;.&lt;/p&gt;
&lt;h2 id="performance-of-a-transformer-with-taylorshift"&gt;Performance of a Transformer with TaylorShift&lt;/h2&gt;
&lt;p&gt;We test the accuracy of multiple (efficient) Transformers on a set of 5 tasks: three from the Long Range Arena benchmark
, as well as ImageNet classification at two model sizes.
We use the same standard hyperparameters for all models. Models marked with * had to be trained in full precision instead of mixed precision. All tasks are classification tasks, and the table shows accuracy in percent.&lt;/p&gt;
&lt;div class="scroll-container"&gt;&lt;table class="sortable table"&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: left"&gt;Model&lt;/th&gt;
&lt;th style="text-align: center"&gt;CIFAR (Pixel)&lt;/th&gt;
&lt;th style="text-align: center"&gt;IMDB (Byte)&lt;/th&gt;
&lt;th style="text-align: center"&gt;ListOps&lt;/th&gt;
&lt;th style="text-align: center"&gt;ImageNet (Ti)&lt;/th&gt;
&lt;th style="text-align: center"&gt;ImageNet (S)&lt;/th&gt;
&lt;th style="text-align: center"&gt;Average&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Linformer
&lt;/td&gt;
&lt;td style="text-align: center"&gt;29.2&lt;/td&gt;
&lt;td style="text-align: center"&gt;58.1&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;ndash;&lt;/td&gt;
&lt;td style="text-align: center"&gt;64.3&lt;/td&gt;
&lt;td style="text-align: center"&gt;76.3&lt;/td&gt;
&lt;td style="text-align: center"&gt;(57.0)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;RFA
&lt;/td&gt;
&lt;td style="text-align: center"&gt;44.9&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;65.8&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;ndash;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;ndash;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&amp;ndash;&lt;/td&gt;
&lt;td style="text-align: center"&gt;(55.4)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Performer
&lt;/td&gt;
&lt;td style="text-align: center"&gt;34.2*&lt;/td&gt;
&lt;td style="text-align: center"&gt;65.6*&lt;/td&gt;
&lt;td style="text-align: center"&gt;35.4*&lt;/td&gt;
&lt;td style="text-align: center"&gt;62.0*&lt;/td&gt;
&lt;td style="text-align: center"&gt;67.1*&lt;/td&gt;
&lt;td style="text-align: center"&gt;52.9&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Reformer
&lt;/td&gt;
&lt;td style="text-align: center"&gt;44.8&lt;/td&gt;
&lt;td style="text-align: center"&gt;63.9&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;47.6&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;73.6&lt;/td&gt;
&lt;td style="text-align: center"&gt;76.2*&lt;/td&gt;
&lt;td style="text-align: center"&gt;61.2&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Nyströmformer
&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;49.4&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;65.6&lt;/td&gt;
&lt;td style="text-align: center"&gt;44.5&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;75.0&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.3*&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;62.6&lt;/ins&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;EVA
&lt;/td&gt;
&lt;td style="text-align: center"&gt;46.1&lt;/td&gt;
&lt;td style="text-align: center"&gt;64.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;45.3&lt;/td&gt;
&lt;td style="text-align: center"&gt;73.4&lt;/td&gt;
&lt;td style="text-align: center"&gt;78.2&lt;/td&gt;
&lt;td style="text-align: center"&gt;61.4&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;Transformer
&lt;/td&gt;
&lt;td style="text-align: center"&gt;44.7&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;65.8&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;46.0&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;75.6&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;79.1&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;62.2&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: left"&gt;&lt;strong&gt;&lt;em&gt;TaylorShift&lt;/em&gt; (ours)&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;47.6&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;66.2&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;46.1&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;ins&gt;75.0&lt;/ins&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;79.3&lt;/strong&gt;&lt;/td&gt;
&lt;td style="text-align: center"&gt;&lt;strong&gt;62.8&lt;/strong&gt;&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
&lt;p&gt;This shows &lt;em&gt;TaylorShift&amp;rsquo;s&lt;/em&gt; consistent performance across various datasets.
It surpasses the standard Transformer on 4 out of 5 datasets and achieves the best average accuracy, positioning itself as a good all-rounder model.
We observe a notable increase of $4.3\\%$ when transitioning from size Ti to S on ImageNet, in contrast to $3.5\\%$ for the Transformer, which highlights &lt;em&gt;TaylorShift&amp;rsquo;s&lt;/em&gt; scalability.&lt;/p&gt;
&lt;h2 id="number-of-attention-heads"&gt;Number of attention heads&lt;/h2&gt;
&lt;p&gt;We train &lt;em&gt;TaylorShift&lt;/em&gt; models on the pixel-level CIFAR10 task to see how accuracy and efficiency change.
All models have the default $d_\text{emb} = 256$ with $d = \frac{d_\text{emb}}{h}$ in the attention mechanism.
The default is $h = 4$.&lt;/p&gt;
&lt;div class="scroll-container"&gt;&lt;table class="sortable table"&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: right"&gt;$h$&lt;/th&gt;
&lt;th style="text-align: right"&gt;$d$&lt;/th&gt;
&lt;th style="text-align: right"&gt;Acc [%]&lt;/th&gt;
&lt;th style="text-align: right"&gt;&lt;em&gt;dir-TS&lt;/em&gt;&lt;br&gt;TP [ims/s]&lt;/th&gt;
&lt;th style="text-align: right"&gt;&lt;em&gt;dir-TS&lt;/em&gt;&lt;br&gt;Mem [MiB@16]&lt;/th&gt;
&lt;th style="text-align: right"&gt;&lt;em&gt;eff-TS&lt;/em&gt;&lt;br&gt;TP [ims/s]&lt;/th&gt;
&lt;th style="text-align: right"&gt;&lt;em&gt;eff-TS&lt;/em&gt;&lt;br&gt;Mem [MiB@16]&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: right"&gt;4&lt;/td&gt;
&lt;td style="text-align: right"&gt;64&lt;/td&gt;
&lt;td style="text-align: right"&gt;47.1&lt;/td&gt;
&lt;td style="text-align: right"&gt;12060&lt;/td&gt;
&lt;td style="text-align: right"&gt;596&lt;/td&gt;
&lt;td style="text-align: right"&gt;2975&lt;/td&gt;
&lt;td style="text-align: right"&gt;840&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: right"&gt;8&lt;/td&gt;
&lt;td style="text-align: right"&gt;32&lt;/td&gt;
&lt;td style="text-align: right"&gt;47.5&lt;/td&gt;
&lt;td style="text-align: right"&gt;7657&lt;/td&gt;
&lt;td style="text-align: right"&gt;1111&lt;/td&gt;
&lt;td style="text-align: right"&gt;5749&lt;/td&gt;
&lt;td style="text-align: right"&gt;585&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: right"&gt;16&lt;/td&gt;
&lt;td style="text-align: right"&gt;16&lt;/td&gt;
&lt;td style="text-align: right"&gt;47.3&lt;/td&gt;
&lt;td style="text-align: right"&gt;4341&lt;/td&gt;
&lt;td style="text-align: right"&gt;2135&lt;/td&gt;
&lt;td style="text-align: right"&gt;9713&lt;/td&gt;
&lt;td style="text-align: right"&gt;459&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: right"&gt;32&lt;/td&gt;
&lt;td style="text-align: right"&gt;8&lt;/td&gt;
&lt;td style="text-align: right"&gt;46.9&lt;/td&gt;
&lt;td style="text-align: right"&gt;2528&lt;/td&gt;
&lt;td style="text-align: right"&gt;4187&lt;/td&gt;
&lt;td style="text-align: right"&gt;14087&lt;/td&gt;
&lt;td style="text-align: right"&gt;397&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: right"&gt;64&lt;/td&gt;
&lt;td style="text-align: right"&gt;4&lt;/td&gt;
&lt;td style="text-align: right"&gt;45.9&lt;/td&gt;
&lt;td style="text-align: right"&gt;1235&lt;/td&gt;
&lt;td style="text-align: right"&gt;8291&lt;/td&gt;
&lt;td style="text-align: right"&gt;13480&lt;/td&gt;
&lt;td style="text-align: right"&gt;125&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;/div&gt;
&lt;p&gt;We see that increasing the number of attention heads $h$ increases the speed and decreases the memory needed by &lt;em&gt;efficient-TaylorShift&lt;/em&gt;, as predicted.
Additionally, we find that it also increases accuracy up to a certain point.
Up to that point, we have a win-win-win situation with a faster, more lightweight, &lt;em&gt;and&lt;/em&gt; more accurate model.
Beyond it, the number of heads can be used to trade off accuracy against the amount of compute needed.&lt;/p&gt;
&lt;h1 id="conclusion--outlook"&gt;Conclusion &amp;amp; Outlook&lt;/h1&gt;
&lt;p&gt;We introduced &lt;em&gt;TaylorShift&lt;/em&gt;, a novel efficient Transformer model.
It offers significant computational advantages without sacrificing performance.
By approximating the exponential function, TaylorShift achieves linear time and memory complexity, making it ideal for long sequences.
Our experiments demonstrate its superiority over standard Transformers in terms of speed, memory efficiency, and even accuracy.&lt;/p&gt;
&lt;p&gt;As we move forward, we envision TaylorShift opening up new possibilities for tackling challenging tasks involving lengthy sequences.
From high-resolution image processing to large-scale document analysis, TaylorShift&amp;rsquo;s efficiency and versatility make it a promising tool for the future of efficient Transformer models.&lt;/p&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more details, see the
or the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h1 id="references"&gt;References&lt;/h1&gt;
&lt;ol&gt;
&lt;li&gt;&lt;a name="Performer"&gt;K.M. Choromanski, V. Likhosherstov, D. Dohan, X. Song, A. Gane, T. Sarlos, P. Hawkins, J.Q. Davis, A. Mohiuddin, L. Kaiser, D.B. Belanger, L.J. Colwell, and A. Weller&lt;/a&gt; &amp;ldquo;&lt;em&gt;Rethinking attention with performers&lt;/em&gt;&amp;rdquo;. ICLR 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="Reformer"&gt;N. Kitaev, L. Kaiser, and A. Levskaya.&lt;/a&gt; &amp;ldquo;&lt;em&gt;Reformer: The efficient transformer&lt;/em&gt;&amp;rdquo;. ICLR 2020.&lt;/li&gt;
&lt;li&gt;&lt;a name="RFA"&gt;H. Peng, N. Pappas, D. Yogatama, R. Schwartz, N.A. Smith, and L. Kong&lt;/a&gt; &amp;ldquo;&lt;em&gt;Random feature attention&lt;/em&gt;&amp;rdquo;. ICLR 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="LRA"&gt;Y. Tay, M. Dehghani, S. Abnar, Y. Shen, D. Bahri, P. Pham, J. Rao, L. Yang, S. Ruder, and D. Metzler&lt;/a&gt; &amp;ldquo;&lt;em&gt;Long range arena: A benchmark for efficient transformers&lt;/em&gt;&amp;rdquo; ICLR 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="Transformer"&gt;A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A.N. Gomez, L. Kaiser, and I. Polosukhin&lt;/a&gt; &amp;ldquo;&lt;em&gt;Attention is all you need&lt;/em&gt;&amp;rdquo;. NeurIPS 2017.&lt;/li&gt;
&lt;li&gt;&lt;a name="Linformer"&gt;S. Wang, B.Z. Li, M. Khabsa, H. Fang, and H. Ma&lt;/a&gt; &amp;ldquo;&lt;em&gt;Linformer: Self-attention with linear complexity&lt;/em&gt;&amp;rdquo;. ArXiv Prerint 2020.&lt;/li&gt;
&lt;li&gt;&lt;a name="Nystromformer"&gt;Y. Xiong, Z. Zeng, R. Chakraborty, M. Tan ,G. Fung, Y. Li, and V. Singh&lt;/a&gt; &amp;ldquo;&lt;em&gt;Nyströmformer: A nyström-based algorithm for approximating self-attention&lt;/em&gt;&amp;rdquo;. AAAI 2021.&lt;/li&gt;
&lt;li&gt;&lt;a name="EVA"&gt;L. Zheng, J. Yuan, C. Wang, and L. Kong&lt;/a&gt; &amp;ldquo;&lt;em&gt;Efficient attention via control variates&lt;/em&gt;&amp;rdquo;. ICLR 2023.&lt;/li&gt;
&lt;/ol&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
&lt;/p&gt;</description></item><item><title>Distill the Best, Ignore the Rest: Improving Dataset Distillation with Loss-Value-Based Pruning</title><link>https://nauen-it.de/publications/distill-best-ignore-rest/</link><pubDate>Mon, 18 Nov 2024 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/distill-best-ignore-rest/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
,
&lt;/p&gt;</description></item><item><title>Just Leaf It: Accelerating Diffusion Classifiers with Hierarchical Class Pruning</title><link>https://nauen-it.de/publications/just-leaf-it/</link><pubDate>Mon, 18 Nov 2024 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/just-leaf-it/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
,
&lt;/p&gt;</description></item><item><title>A Low-Resolution Image is Worth 1x1 Words: Enabling Fine Image Super-Resolution with Transformers and TaylorShift</title><link>https://nauen-it.de/publications/taylor-shift-super-resolution/</link><pubDate>Fri, 15 Nov 2024 13:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/taylor-shift-super-resolution/</guid><description>
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;This work builds on the
attention mechanism.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;p&gt;&lt;strong&gt;Associated Projects:&lt;/strong&gt;
,
,
&lt;/p&gt;</description></item><item><title>Stochastic Control with Signatures</title><link>https://nauen-it.de/publications/signature-stochastic-control/</link><pubDate>Mon, 03 Jun 2024 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/signature-stochastic-control/</guid><description>&lt;p&gt;This work builds on my
.&lt;/p&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;</description></item><item><title>Stochastic Optimal Control using Signatures</title><link>https://nauen-it.de/publications/master-thesis-signatures/</link><pubDate>Thu, 09 Jun 2022 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/master-thesis-signatures/</guid><description>&lt;h1 id="1-introduction"&gt;1 Introduction&lt;/h1&gt;
&lt;p&gt;In this thesis, we consider a stochastic control problem of the form
&lt;/p&gt;
$$
dY_t = \mu_t b(Y_t) dt + \sigma(Y_t) dB_t,
$$&lt;p&gt;
where $\mu_t$ is an $\mathcal F_t = \sigma(B_s | s \leq t)$-measurable, continuous process that we have some control over.
An SDE of this form arises when one considers a noisy process in which only the drift, i.e. the average direction, can be controlled.
This control manifests itself in the function $\mu : [0, T] \to \mathbb R$.&lt;/p&gt;
&lt;p&gt;A toy example for a problem of this kind could be modeling navigating on the seas or in space, where the random part is the combined influence of winds and currents on a boat and $\mu_t$ represents the direction of the rudder, or in the space example, the randomness represents course altering events like solar winds and $\mu_t$ is the direction or strength of thrust.
A similar optimal control problem with control in the drift was considered in (Diehl, Fritz, and Gassiat, 2017), which investigates the value function to find a dual problem.
This was the first paper on stochastic optimal control, using rough path analysis.&lt;/p&gt;
&lt;p&gt;We now use the ansatz
\begin{align*}
&amp;amp;\mu_t = \Theta(B|_{[0, t]}) &amp;amp;\Theta \in C( \Lambda_T, \mathbb R) =: \mathcal T,
\end{align*}
with $\Lambda_T$ being the space of stopped rough paths up to time $T$ (see Definition 5.2).&lt;/p&gt;
&lt;p&gt;This gives the SDE
&lt;/p&gt;
$$
dY_t^\mu = \underbrace{\Theta(\hat B|_{[0, t]})}_{= \mu_t} b(Y_t^\mu) dt + \sigma(Y_t^\mu) dB_t.
$$&lt;p&gt;We can now define a loss-function like
\begin{align*}
L(Y^\mu) := \mathbb E(Y_T^\mu)^2 + \mathbb E(\left|Y_T^\mu\right|^2),
\end{align*}
but in general any Lipschitz or Hölder continuous loss $L : C([0, T], \mathbb R^m) \to \mathbb R^+$ is possible.&lt;/p&gt;
&lt;p&gt;The question we want to answer is:
What is $\inf_\mu \mathbb E[L(Y^\mu)]$ and what does the corresponding $\mu$ (and $\Theta$) look like?
It is the question of how to act optimally while counteracting random noise.&lt;/p&gt;
&lt;p&gt;First, we need to understand our main problem SDE.
This is a shorthand notation for
\begin{align}\label{eq:main_problem_integral_form}
Y_t - Y_0 = \int_0^t \mu_\tau b(Y_\tau) d\tau + \int_0^t \sigma(Y_\tau) dB_\tau.
\end{align}
The second integral here can be seen as an Itô integral. However, we will view it as an integral against a rough path, a so-called rough integral.
This is a generalization of the Itô map which can also incorporate other types of stochastic integrals, like the Stratonovich integral.
This change of perspective is useful since we want to look at the so-called signatures of some processes, which are defined naturally in the context of rough paths.&lt;/p&gt;
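&lt;p&gt;Concretely (in one standard formulation; the precise construction and notation follow in Section 2), the rough integral of a controlled path $(Y, Y')$ against a rough path $(X, \mathbb X)$ arises as the limit of compensated Riemann sums
&lt;/p&gt;
$$
\int_0^T Y_r \, dX_r = \lim_{|\mathcal P| \to 0} \sum_{[s, t] \in \mathcal P} \left( Y_s X_{s, t} + Y'_s \mathbb X_{s, t} \right),
$$&lt;p&gt;
where $X_{s, t} = X_t - X_s$ and $\mathbb X_{s, t}$ denotes the second-level iterated integral; the compensation term $Y'_s \mathbb X_{s, t}$ is what makes the limit exist despite the roughness of $X$.&lt;/p&gt;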
&lt;p&gt;The theory of rough paths was first introduced in the 1990s by Terry Lyons.
It is an elegant framework for path-wise integration with rough driving signals and is therefore suited to a general class of stochastic processes, like Brownian motion or fractional Brownian motion.
In particular, rough integrals are a generalization of Young&amp;rsquo;s theory of integration.
An important aspect of the theory is the continuity of the solution map of rough differential equations, which is not given in the classical case of Itô SDEs, where the solution map is measurable but not continuous.&lt;/p&gt;
&lt;p&gt;In addition to theoretical advances in SDEs, there were additional tools developed for rough paths, most notably the signature.
The signature $\mathbb X^{&lt; \infty}$ of a path $x: [0, T] \to \mathbb R^n$ is a collection of iterated integrals of all components of the path against each other;
\begin{align*}
\int_{0 \leq t_1 \leq \dots \leq t_k \leq t} dx_{t_1}^{i_1} \dots dx_{t_k}^{i_k}
\end{align*}
for $k \in \mathbb N$ and $i_1, ..., i_k \in \lbrace 1, ..., n\rbrace$.
Now, the values of the signature up to a certain level $k$, which depends on the roughness of the underlying path, have to be specified as additional data.
To see why this is true, one can consider the difference between the Itô and Stratonovich integrals, which are both valid definitions of integration with respect to Brownian motion.
We have
\begin{align*}
\int_{0 \leq t_1 \leq t_2 \leq T} dB_{t_1} dB_{t_2} = \int_0^T B_t dB_t = \frac{B_T^2}{2} - \frac{T}{2},
\end{align*}
but also
\begin{align*}
\int_{0 \leq t_1 \leq t_2 \leq T} \circ dB_{t_1} \circ dB_{t_2} = \int_0^T B_t \circ dB_t = \frac{B_T^2}{2},
\end{align*}
which makes it clear that there is no single canonical way of defining the signature of a process.
This is why, when working with iterated integrals, one has to fix one way of computing them.
The theory of rough paths gives a framework for doing exactly that.
The signature of a path is important because the signature at time $t$ determines the whole path up to time $t$, up to so-called tree-like extensions.
In particular, the signature of an augmented rough path, i.e. a path $x_t = (x_t^{(1)}, ..., x_t^{(n)})$ with an additional dimension representing time,
\begin{align*}
\hat x_t = (x_t^{(1)}, &amp;hellip;, x_t^{(n)}, t) \in \mathbb R^{n + 1}
\end{align*}
determines the path uniquely.
This makes the signature an important tool in machine learning, as a model-free way to extract features from time-series data like audio, speech, or character drawings.
As such, it has been used successfully in several machine learning applications, including Chinese character recognition and even medical tasks like the recognition of mental disorders.&lt;/p&gt;
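&lt;p&gt;To make this concrete, the following minimal Python sketch (grid size and seed are arbitrary, illustrative choices) computes the first two signature levels of a sampled Brownian path via iterated sums, using a left-point rule for the Itô convention and a midpoint rule for the Stratonovich convention, and reproduces the discrepancy above.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;import numpy as np

rng = np.random.default_rng(0)

# Sample a Brownian path on a fine grid.
T, n_steps = 1.0, 100_000
dt = T / n_steps
dB = rng.normal(0.0, np.sqrt(dt), n_steps)
B = np.concatenate([[0.0], np.cumsum(dB)])

# Level 1: the increment B_T - B_0.
level1 = B[-1] - B[0]

# Level 2: the iterated integral of B against itself, computed two ways.
ito = np.sum(B[:-1] * dB)                    # left-point sums: Itô integral
strat = np.sum(0.5 * (B[:-1] + B[1:]) * dB)  # midpoint sums: Stratonovich integral

print("level 1:", level1)
print("Itô level 2:         ", ito, "  expected (B_T^2 - T)/2 =", (B[-1] ** 2 - T) / 2)
print("Stratonovich level 2:", strat, "  expected  B_T^2 / 2   =", B[-1] ** 2 / 2)
&lt;/code&gt;&lt;/pre&gt;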
&lt;p&gt;The property of injectivity of the signature map also makes it important to us, and it is why we take the following ansatz for answering the question from above:
&lt;/p&gt;
$$
\Theta(\hat B|_{[0, t]}) = \langle \ell, \hat{B}_{0, t}^{&lt; \infty} \rangle.
$$&lt;p&gt;
Here, $\hat{B}_{0, t}^{&lt; \infty}$ is the signature of the augmented path of Brownian motion.
In this, we will follow the reasoning of (Kalsi, Lyons, and Arribas, 2020) and (Bayer et al., 2022), where it was shown that similar control problems of optimal trading speed and optimal stopping can be solved by just using linear maps of the path signature.&lt;/p&gt;
&lt;p&gt;The main result of this thesis will be&lt;/p&gt;
&lt;h3 id="theorem-56"&gt;Theorem 5.6:&lt;/h3&gt;
&lt;p&gt;&lt;em&gt;Let $2 \leq p &lt; 3$ and let $\mathbb P$ be a probability measure on $\left( \hat \Omega^p_T, \mathcal B(\hat \Omega^p_T) \right)$.&lt;/em&gt;
&lt;em&gt;Let $Y^\mu$ be the unique solution to&lt;/em&gt;
\begin{align*}
dY = \mu_t b(Y_t) dt + \sigma(Y_t) d\mathbf x
\end{align*}
&lt;em&gt;started at $\xi \in \mathbb R^m$, with $\mu \in \mathcal T$, $b$ Lipschitz, and $\sigma \in C^3_b(\mathbb R^m, \mathbb R^{m \times n})$.&lt;/em&gt;
&lt;em&gt;Here, the $\mathbf x$ is a random geometric $p$-rough path with distribution determined by $\mathbb P$.&lt;/em&gt;
&lt;em&gt;It holds&lt;/em&gt;
\begin{align*}
\inf_{\mu \in \mathcal T} \mathbb E [L(Y^\mu)] = \inf_{\mu \in \mathcal{T}_{sig}} \mathbb E [L(Y^\mu)]
\end{align*}
&lt;em&gt;for a loss function $L : C([0, T], \mathbb R^m) \to \mathbb R$ bounded and $\alpha$-Hölder for some $\alpha &gt; 0$.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;Here $\mathcal T = C(\Lambda_T, \mathbb R)$ is the set of all continuous functions of the path up to some time $t \in [0, T]$, while $\mathcal{T}_{sig}$ is the set of all functions of the form $\langle \ell, \hat{\mathbb{X}}_{0, t}^{&lt; \infty} \rangle$.
The theorem therefore says that the optimal control problem can be solved by considering only linear maps of the signature of the augmented path.
The statement will then also be extended to Itô integrals, as considered in the problem SDE, in Theorem 5.6.&lt;/p&gt;
&lt;p&gt;Using these theorems, we can tackle our question numerically by modeling $\mu_t = \langle \ell, \hat{B}_{0, t}^{\leq k} \rangle$ as a linear map of the truncated signature.
Here, we drop from the infinite-dimensional, full signature $\hat B_{0, t}^{&lt; \infty}$ to the finite-dimensional, truncated signature $\hat B_{0, t}^{\leq k}$ for numerical reasons.
This is a good approximation, as
\begin{align*}
\left|\left|{\hat B_{s, t}^k}\right|\right| \leq C \frac{\omega(s, t)^{\frac k p}}{\left( \frac{k}{p} \right) ! }
\end{align*}
(see Theorem 3.7 in Lyons, Caruana, and Lévy, 2007), i.e. the norms of the higher signature levels decay factorially in $k$.
We can approximate the RDE&amp;rsquo;s solution by using a Milstein scheme (Algorithm 3) on a discrete time-grid
\begin{align*}
0 = t_0 &amp;lt; t_1 &amp;lt; &amp;hellip; &amp;lt; t_k = T
\end{align*}
and estimating the expected loss $\mathbb E[L(Y^\mu)]$ by averaging over many such simulations.
Optimizing $\ell$ with gradient descent via backpropagation can then bring us arbitrarily close to the optimal control $\mu_t$.&lt;/p&gt;
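&lt;p&gt;The following PyTorch sketch illustrates this optimization loop on a toy one-dimensional instance; the choices of $b$, $\sigma$, the terminal loss, and the level-2 truncation are illustrative, and a plain Euler-Maruyama step is used here in place of the Milstein scheme of Algorithm 3.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;import torch

torch.manual_seed(0)

# Toy instance of the control problem dY = mu_t b(Y) dt + sigma(Y) dB with
# terminal loss E[|Y_T|^2]; b, sigma and the grid are illustrative choices.
T, n_steps, n_paths = 1.0, 50, 4096
dt = T / n_steps
b = lambda y: torch.ones_like(y)
sigma = lambda y: 0.5 * torch.ones_like(y)

def sig_features(dB):
    """Signature of the time-augmented path (t, B_t), truncated at level 2,
    evaluated at every grid point via left-point (Itô-style) iterated sums."""
    zeros = torch.zeros(dB.shape[0], 1)
    B = torch.cat([zeros, torch.cumsum(dB, dim=1)], dim=1)   # B at t_0, ..., t_N
    t = (torch.arange(n_steps + 1) * dt).expand_as(B)
    def itersum(f_left, dX):                                  # integral of f dX
        return torch.cat([zeros, torch.cumsum(f_left * dX, dim=1)], dim=1)
    dt_inc = torch.full_like(dB, dt)
    feats = [torch.ones_like(B), t, B,                        # levels 0 and 1
             t ** 2 / 2,                                      # int s ds
             itersum(t[:, :-1], dB),                          # int s dB_s
             itersum(B[:, :-1], dt_inc),                      # int B_s ds
             itersum(B[:, :-1], dB)]                          # int B_s dB_s
    return torch.stack(feats, dim=-1)                         # (paths, N + 1, 7)

ell = torch.zeros(7, requires_grad=True)                      # the linear functional
opt = torch.optim.Adam([ell], lr=0.05)

for it in range(200):
    dB = torch.randn(n_paths, n_steps) * dt ** 0.5
    mu = sig_features(dB) @ ell            # mu at t_k: pairing of ell and signature
    Y = torch.zeros(n_paths)
    for k in range(n_steps):               # Euler-Maruyama (the thesis uses Milstein)
        Y = Y + mu[:, k] * b(Y) * dt + sigma(Y) * dB[:, k]
    loss = (Y ** 2).mean()                 # Monte Carlo estimate of the expected loss
    opt.zero_grad(); loss.backward(); opt.step()

print("estimated optimal expected loss:", float(loss))
&lt;/code&gt;&lt;/pre&gt;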
&lt;p&gt;First, in Section 2, we introduce the theory of rough paths with its basic facts and definitions and derive rough integrals as limits of Riemann-like sums.
Throughout the thesis, we will work with general rough paths with finite $p$-variation for $p \in [2, 3)$, where Young integration breaks down.
For ease of notation, we will introduce a tensor calculus.
In this section, we also establish a general setting of controlled rough paths that handles all kinds of rough paths, as opposed to the theory of (Friz and Hairer, 2020), which only considers $\alpha$-Hölder paths.
After that, in Section 3, we will deal with rough differential equations (RDEs).
We will prove the existence and uniqueness of solutions in the usual way via Picard iteration, but then extend the theory to RDEs with a drift term, requiring only very mild assumptions on the drift, so that all RDEs of the form seen in the problem SDE with $b$ Lipschitz and $\mu$ continuous are covered.
We also investigate the stability of RDEs in the drift term.
After having introduced RDEs, we will move on to signatures in Section 4, where we will see the basic definitions, along with a proof of the shuffle identity for geometric rough paths.
This is directly followed by the proof of our main theorem, Theorem 5.6, in Section 5.
Here, we will exploit the notion of stopped rough paths, as well as Lemma 5.5, which has also been used in (Kalsi, Lyons, and Arribas, 2020) and (Bayer et al., 2022), to show the density of signature controls on compact sets of arbitrarily high probability ($&lt; 1$).
We then expand the main theorem to work with Itô-integrals.
After proving the theoretical results, we will go on to state numerical algorithms which can be used for approximation and which are also implemented and can be viewed on
, as well as some convergence results for said algorithms in Section 6.
Then, in Section 7, we test our implementation against a Julia reference implementation on two SDE problems.
We also use our framework to solve an optimal asset allocation problem in the Black-Scholes model.
The SDE of this problem has a different structure than the ones considered before, and we argue why the same approach we took (approximating $\mu \in \mathcal T$ by $\mu \in \mathcal T_{sig}$) can also be used when one has combined control over the drift and volatility terms.
Here, we use the Markov property of Brownian motion and model the control with a neural network as an element of $C(\mathbb R^{m + 1}, \mathbb R)$ instead of a linear function of the signature of the process.
Finally, in Section 8, we discuss some extensions of the problem, as well as different possibilities for defining the Gubinelli derivative of RDE solutions when dealing with a drift term.&lt;/p&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h2 id="references"&gt;References&lt;/h2&gt;
&lt;p&gt;(Diehl, Friz, and Gassiat, 2017) &lt;em&gt;Joscha Diehl, Peter K. Friz, and Paul Gassiat. &amp;lsquo;&amp;lsquo;Stochastic control with rough paths&amp;rsquo;&amp;rsquo;. In: Applied Mathematics &amp;amp; Optimization 75, pp. 285-315, 2017.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;(Kalsi, Lyons, and Arribas, 2020) &lt;em&gt;Jasdeep Kalsi, Terry Lyons, and Imanol Perez Arribas. &amp;lsquo;&amp;lsquo;Optimal execution with rough path signatures&amp;rsquo;&amp;rsquo;. In: SIAM Journal on Financial Mathematics 11, pp. 470-493, 2020.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;(Bayer et al., 2022) &lt;em&gt;Christian Bayer et al. &amp;lsquo;&amp;lsquo;Optimal stopping with signatures&amp;rsquo;&amp;rsquo;. In: The Annals of Applied Probability, 2022.&lt;/em&gt;&lt;/p&gt;
&lt;p&gt;(Lyons, Caruana, and Lévy, 2007) &lt;em&gt;Terry J. Lyons, Michael J. Caruana, and Thierry Lévy. &amp;lsquo;&amp;lsquo;Differential equations driven by rough paths&amp;rsquo;&amp;rsquo;. In: Lecture Notes in Mathematics, Springer Berlin Heidelberg, 2007.&lt;/em&gt;&lt;/p&gt;</description></item><item><title>Explaining Graph Neural Networks</title><link>https://nauen-it.de/publications/bachelor-thesis-gnns/</link><pubDate>Fri, 08 Oct 2021 00:00:00 +0000</pubDate><guid>https://nauen-it.de/publications/bachelor-thesis-gnns/</guid><description>&lt;h2 id="introduction"&gt;Introduction&lt;/h2&gt;
&lt;p&gt;In this bachelor thesis, we explore and evaluate different methods of explaining Graph Neural Networks (GNNs).
Graph Neural Networks are an emerging class of neural networks that take graphs as their input data.
This is especially useful since graphs are highly flexible and powerful data structures that can express a set of different data points with complex relationships between them.
The motivation for developing graph neural networks comes from the overwhelming success of convolutional neural networks, which can be seen as a special case of GNNs: they operate on images by exploiting neighborhood information, which can also be expressed as a graph.
Today, graph neural networks are used in a wide array of domains, such as the prediction of molecular properties in chemistry, drug discovery, medical diagnosis, modeling the spread of disease, recommendation systems, and natural language processing.&lt;/p&gt;
&lt;p&gt;But why would one want to explain these networks?
Methods for explaining neural models are used to perform a wide range of tasks.
The first one is to debug the model and increase performance, as explanation methods can uncover model bias or spurious correlations in the training data.
These insights are then used to clean up or expand the training data, or to adjust the model class, to achieve better performance and generalization.&lt;/p&gt;
&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th style="text-align: center"&gt;Model&lt;/th&gt;
&lt;th style="text-align: center"&gt;Prediction&lt;/th&gt;
&lt;th style="text-align: center"&gt;Explanation&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td style="text-align: center"&gt;$A$&lt;/td&gt;
&lt;td style="text-align: center"&gt;Positive&lt;/td&gt;
&lt;td style="text-align: center"&gt;Even though the Icelandic &lt;span style="background-color: #FF6A00"&gt;scenery is incredibly stunning&lt;/span&gt;, the story can&amp;rsquo;t keep up, and therefore the overall experience is boring.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: center"&gt;$B$&lt;/td&gt;
&lt;td style="text-align: center"&gt;Negative&lt;/td&gt;
&lt;td style="text-align: center"&gt;Even though the &lt;span style="background-color: #FF6A00"&gt;Icelandic scenery&lt;/span&gt; is incredibly stunning, the story can&amp;rsquo;t keep up, and therefore the &lt;span style="background-color: #FF6A00"&gt;overall experience&lt;/span&gt; is boring.&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td style="text-align: center"&gt;$C$&lt;/td&gt;
&lt;td style="text-align: center"&gt;Negative&lt;/td&gt;
&lt;td style="text-align: center"&gt;Even though the Icelandic scenery is incredibly stunning, the &lt;span style="background-color: #FF6A00"&gt;story can&amp;rsquo;t keep up&lt;/span&gt;, and therefore the &lt;span style="background-color: #FF6A00"&gt;overall experience is boring&lt;/span&gt;.&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;
&lt;p&gt;For example, if we want to categorize reviews into positive and negative ones, we are also interested in exactly why our model decides that a given review is positive or negative.
Using this information, we can more accurately judge the model&amp;rsquo;s performance by checking whether its predictions are correct for the right reasons.
The example explanations in the table above reveal that model $B$ is correct for the wrong reason, while model $C$ is correct for the right reason.
Therefore, model $C$ should be deployed over model $B$ since we can expect $C$ to generalize better to new, unseen data.&lt;/p&gt;
&lt;p&gt;A second application area of explanations is to assess the suitability of a model for use in the real world.
This is especially important in high-stakes environments, such as medicine or law enforcement, where graph neural networks are used.
Therefore, explainability is also part of approval processes by regulatory authorities, like the European Union, and in some companies.
Another way explanation techniques are useful is by hinting at what to change in the input to receive a different model output.
This is useful, for example, in loan approval, if the client wants to know which factors to change in order to be approved.&lt;/p&gt;
&lt;p&gt;One distinguishes two forms of explanations: global and local ones.
Global explanations explain the model as a whole, but it is often not feasible to construct them, especially for a model with a lot of parameters, since such a model is simply too complex to be understood in its entirety.
Therefore, we want to focus on local explanations.
These don&amp;rsquo;t attempt to explain the whole model, but just a single decision of the model given a certain input.
The explanations in the table above are local ones.&lt;/p&gt;
&lt;p&gt;This raises the question of how one can explain the decisions of a graph neural network.
To answer this question, we will lay out the relevant techniques for generating attribution weights, as well as expand on them.
Attribution weights are ways of explaining neural models by associating a weight with different parts, or tokens, of the model&amp;rsquo;s input.
These tokens could be pixels in a picture, words in a text, or nodes and/or edges in a graph.
The parts with high weight are seen as more important for the model&amp;rsquo;s decision than those with low weights.
If all generated weights are zero or one, the technique is called a hard attribution technique.
These mark relevant parts of the input, as is the case in the table above.
When a range of real numbers is allowed as weights, the attributions are called soft.
We will focus on soft attribution techniques, as these provide a relation of importance on the inputs&amp;rsquo; tokens.&lt;/p&gt;
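&lt;p&gt;As a minimal illustration of a soft, gradient-based attribution, the following PyTorch sketch computes node-level saliency weights for a toy one-layer message-passing model; the graph, features, and model are illustrative stand-ins, not one of the architectures evaluated in this thesis.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;import torch

torch.manual_seed(0)

# Toy graph: 4 nodes, adjacency with self-loops, random node features.
A = torch.tensor([[1., 1., 0., 0.],
                  [1., 1., 1., 0.],
                  [0., 1., 1., 1.],
                  [0., 0., 1., 1.]])
X = torch.randn(4, 3, requires_grad=True)        # node features

# Minimal one-layer GCN-style model with mean pooling into graph-level logits.
W = torch.randn(3, 2)
def model(A, X):
    deg = A.sum(dim=1, keepdim=True)
    H = torch.relu((A @ X) / deg @ W)            # mean neighborhood aggregation
    return H.mean(dim=0)                         # logits for 2 classes

logits = model(A, X)
pred = logits.argmax()

# Gradient-based (saliency) attribution: gradient of the predicted class score
# w.r.t. the node features, aggregated per node into a soft attribution weight.
logits[pred].backward()
node_weights = X.grad.abs().sum(dim=1)
print("predicted class:", int(pred))
print("node attribution weights:", node_weights)
&lt;/code&gt;&lt;/pre&gt;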
&lt;p&gt;The second question that arises is which technique one should use to explain GNNs, and how to judge whether one technique is better than another.
To answer these questions, we first establish and explain the notion of graph neural networks, as well as different architectures.
Then we introduce some gradient-based attribution techniques and the interpretability-by-design approach of KEdge, which was introduced by (Rathee et al., 2021).
It works by sampling a mask for the edges of a graph via an approximation of the Bernoulli distribution.
This mask can then be used to generate attribution weights.
In the original paper, this approximation is based on the Kumaraswamy, or Kuma, distribution.
In our third chapter, we define some probability distributions to construct different approximations of the Bernoulli distribution that we can use with KEdge.
We also talk about how to obtain node-level attribution weights from KEdge.
Then we introduce some metrics to measure the performance of the different attribution techniques, and in particular, we extend the notion of fidelity to soft attribution techniques by introducing integrated fidelity.&lt;/p&gt;
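&lt;p&gt;To give an idea of the kind of Bernoulli approximation involved, the sketch below draws a differentiable edge mask from a Kumaraswamy distribution that is stretched and rectified so that exact 0 and 1 values occur with positive probability; the parameterization and the stretching interval are illustrative choices and not necessarily the ones used by KEdge.&lt;/p&gt;
&lt;pre&gt;&lt;code class="language-python"&gt;import torch

torch.manual_seed(0)

def hard_kuma_sample(a, b, l=-0.1, r=1.1):
    """Differentiable approximation of a Bernoulli sample:
    draw Kuma(a, b) by inverse-CDF sampling, stretch to (l, r), clamp to [0, 1].
    Exact 0 and 1 values occur with positive probability, while gradients still
    flow into a and b through the reparameterized sample."""
    u = torch.rand_like(a).clamp(1e-6, 1 - 1e-6)    # uniform noise
    x = (1 - (1 - u) ** (1 / b)) ** (1 / a)         # inverse Kumaraswamy CDF
    return torch.clamp(l + (r - l) * x, 0.0, 1.0)   # stretch and rectify

# Example: learnable per-edge parameters for a graph with 5 edges.
log_a = torch.zeros(5, requires_grad=True)
log_b = torch.zeros(5, requires_grad=True)
mask = hard_kuma_sample(log_a.exp(), log_b.exp())   # edge mask with values in [0, 1]
print(mask)

# The mask can then be multiplied onto the adjacency (or message weights) of a
# GNN, and a sparsity penalty on the mask can be added to the training loss.
&lt;/code&gt;&lt;/pre&gt;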
&lt;p&gt;In the main part of this thesis, we conduct three experiments.
The first two evaluate and compare the attribution techniques and examine what effect KEdge has on a model&amp;rsquo;s performance.
Here, we compare the accuracy of different models with and without KEdge, to see whether there is a noticeable difference depending on which underlying probability distribution is used.
We also compare the integrated fidelity values of all the attribution techniques we introduced before.
This is done on the node classification datasets Pubmed, Cora, and CiteSeer and the graph classification dataset MUTAG.
In the last experiment, we use our methods on a text dataset of movie reviews, in order to visualize attribution weights and compare different metrics for evaluating them.&lt;/p&gt;
&lt;div class="callout flex items-baseline gap-2 px-3 py-2 mb-4 rounded-md border-l-4 bg-primary-50 dark:bg-primary-900/30 border-primary-500"
data-callout="note"
data-callout-metadata=""&gt;
&lt;span class="callout-icon shrink-0 translate-y-0.5 text-primary-600 dark:text-primary-400"&gt;
&lt;svg height="20" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"&gt;&lt;path fill="none" stroke="currentColor" stroke-linecap="round" stroke-linejoin="round" stroke-width="1.5" d="m11.25 11.25l.041-.02a.75.75 0 0 1 1.063.852l-.708 2.836a.75.75 0 0 0 1.063.853l.041-.021M21 12a9 9 0 1 1-18 0a9 9 0 0 1 18 0m-9-3.75h.008v.008H12z"/&gt;&lt;/svg&gt;
&lt;/span&gt;
&lt;div class="callout-content text-base dark:text-neutral-300"&gt;
&lt;div class="callout-body"&gt;&lt;p&gt;For more information, see the
.&lt;/p&gt;&lt;/div&gt;
&lt;/div&gt;
&lt;/div&gt;
&lt;h2 id="references"&gt;References&lt;/h2&gt;
&lt;p&gt;(Rathee et al., 2021) &lt;em&gt;Mandeep Rathee et al. &amp;lsquo;&amp;lsquo;Learned Sparsification for Interpretable Graph Neural Networks&amp;rsquo;&amp;rsquo;. In: arXiv: 2106.12920, 2021.&lt;/em&gt;&lt;/p&gt;</description></item></channel></rss>