This document defines a minimal spike implementation for integrating neural style transfer into Unreal Engine 5's rendering pipeline using ONNX Runtime with DirectML.
┌─────────────────────────────────────────────────────────────────────────────────┐
│ UNREAL ENGINE 5 RUNTIME │
├─────────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Game Thread │ │ Render Thread │ │ RHI Thread │ │
│ │ │ │ │ │ │ │
│ │ • Actor Tick │ │ • Scene Proxy │ │ • GPU Commands │ │
│ │ • Component │───▶│ • Render Graph │───▶│ • Resource Mgmt │ │
│ │ Updates │ │ • View Extension│ │ │ │
│ └─────────────────┘ └────────┬────────┘ └─────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────┐ │
│ │ RENDER GRAPH (FRDGBuilder) │ │
│ │ │ │
│ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │
│ │ │ GBuffer │──▶│ Lighting│──▶│ ★ STYLE │──▶│ Post │──▶│ Final │ │ │
│ │ │ Pass │ │ Pass │ │ TRANSFER│ │ Process │ │ Output │ │ │
│ │ └─────────┘ └─────────┘ └────┬────┘ └─────────┘ └─────────┘ │ │
│ │ │ │ │
│ └────────────────────────────────────┼─────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────────────────┐ │
│ │ ★ STYLE TRANSFER SUBSYSTEM (NEW) │ │
│ │ │ │
│ │ ┌────────────────────┐ ┌────────────────────┐ │ │
│ │ │ FStyleTransferPass │ │ UStyleTransferSubsystem │ │
│ │ │ (Render Thread) │ │ (Game Thread) │ │ │
│ │ │ │ │ │ │ │
│ │ │ • Extract ROI │ │ • Owns ONNXRuntime │ │ │
│ │ │ • Copy to Staging │◀──▶│ • Model Loading │ │ │
│ │ │ • Trigger Inference│ │ • Session Mgmt │ │ │
│ │ │ • Composite Result │ │ • Style Parameters │ │ │
│ │ └────────────────────┘ └─────────┬──────────┘ │ │
│ │ │ │ │
│ └───────────────────────────────────────┼──────────────────────────────────┘ │
│ │ │
└──────────────────────────────────────────┼─────────────────────────────────────┘
│
▼
┌──────────────────────────────────────────────────────────────────────────────────┐
│ ONNX RUNTIME (External) │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Session Pool │ │ DirectML EP │ │ Model Cache │ │
│ │ │ │ (GPU Execution) │ │ │ │
│ │ • Thread-safe │───▶│ • ID3D12 Interop│◀───│ • .onnx files │ │
│ │ • Async inference│ │ • Tensor I/O │ │ • Preloaded │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────────┘

UNREAL DEFERRED RENDERING PIPELINE
══════════════════════════════════
Frame N:
─────────────────────────────────────────────────────────────────────────────────▶
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
│ Depth │ │ GBuffer │ │ Shadow │ │ Lighting│ │ Translu-│ │ Post │
│ PrePass │─▶│ Pass │─▶│ Maps │─▶│ Pass │─▶│ cency │─▶│ Process │
└─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘
│
│ SceneColor (HDR)
▼
┌─────────────────────────────┐
│ ★ STYLE TRANSFER PASS │
│ │
│ Input: │
│ • SceneColor (FRDGTexture)│
│ • SceneDepth (optional) │
│ • ObjectMask (optional) │
│ │
│ Output: │
│ • StylizedColor │
│ (FRDGTexture) │
└──────────────┬──────────────┘
│
▼
┌─────────────────────────────┐
│ Composite Pass │
│ (Blend stylized regions │
│ back into SceneColor) │
└──────────────┬──────────────┘
│
▼
┌─────────────────────────────┐
│ Continue to Tonemapping, │
│ UI, Final Output │
└─────────────────────────────┘

UE5 Class Hierarchy
═══════════════════
UGameInstanceSubsystem
└── UStyleTransferSubsystem [GameInstance lifetime]
│
├── FONNXRuntimeWrapper [Owned, handles ONNX C API]
│ ├── OrtEnv*
│ ├── OrtSession*
│ └── OrtMemoryInfo*
│
└── UStyleModelAsset [UObject — UHT requires the U prefix; .uasset wrapper for .onnx]
FSceneViewExtensionBase
└── FStyleTransferViewExtension [Registered per-world]
│
└── FStyleTransferSceneProxy [Per-view render state]
UActorComponent
└── UStylizedObjectComponent [Attached to stylized actors]
├── StyleModelId [FName of a model loaded via UStyleTransferSubsystem::LoadModel]
├── StyleStrength [0.0 - 1.0]
├── UpdateEveryNFrames [1 - 60]
└── GetWorldBounds() [World-space bounds used for ROI extraction]
AActor
└── AStyleZoneTrigger [Box trigger for style regions]
├── ZoneStyleId
└── TransitionDuration

// StyleTransferSubsystem.h
#pragma once

#include "CoreMinimal.h"
#include "Subsystems/GameInstanceSubsystem.h"
#include "StyleTransferSubsystem.generated.h"

// Forward declarations of the ONNX Runtime C API handle types, so this header
// does not have to include onnxruntime_c_api.h (which declares them as plain
// structs).
struct OrtEnv;
struct OrtSession;
struct OrtSessionOptions;
struct OrtMemoryInfo;

/**
 * Tunables for the style-transfer pipeline.
 *
 * ClampMin metadata only constrains values typed into the editor; values
 * written from Blueprint or C++ are not clamped, so consumers should still
 * validate them.
 */
USTRUCT(BlueprintType)
struct FStyleTransferConfig
{
	GENERATED_BODY()

	/** Width in pixels of the image tensor fed to the network. */
	UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer", meta = (ClampMin = "1"))
	int32 InputWidth = 256;

	/** Height in pixels of the image tensor fed to the network. */
	UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer", meta = (ClampMin = "1"))
	int32 InputHeight = 256;

	/** Request GPU execution (presumably the DirectML EP) instead of CPU inference. */
	UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer")
	bool bUseGPU = true;

	/** Upper bound on inferences allowed in flight at the same time. */
	UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer", meta = (ClampMin = "1"))
	int32 MaxConcurrentInferences = 1;
};

/**
 * Game-instance-scoped owner of the ONNX Runtime environment, the CPU/GPU
 * memory-info handle, and one OrtSession per loaded style model (keyed by an
 * FName model id).
 *
 * Threading: LoadedSessions is guarded by SessionLock (declared mutable so the
 * const query functions can take it).
 *
 * NOTE(review): the class-hierarchy design has this subsystem own an
 * FONNXRuntimeWrapper, but this header stores raw ORT handles directly. Pick a
 * single owner for ORT handles so they cannot be released twice.
 */
UCLASS()
class YOURGAME_API UStyleTransferSubsystem : public UGameInstanceSubsystem
{
	GENERATED_BODY()

public:
	//~ USubsystem interface
	virtual void Initialize(FSubsystemCollectionBase& Collection) override;
	virtual void Deinitialize() override;

	/**
	 * Loads the .onnx model at ModelPath and registers its session under ModelId.
	 * @return presumably true on success — confirm against the .cpp.
	 */
	UFUNCTION(BlueprintCallable, Category = "Style Transfer")
	bool LoadModel(const FString& ModelPath, FName ModelId);

	/** Releases the session registered under ModelId, if any. */
	UFUNCTION(BlueprintCallable, Category = "Style Transfer")
	void UnloadModel(FName ModelId);

	/** True when a session is registered under ModelId and can be used for inference. */
	UFUNCTION(BlueprintCallable, Category = "Style Transfer")
	bool IsModelLoaded(FName ModelId) const;

	/** Current pipeline configuration (read-only view). */
	const FStyleTransferConfig& GetConfig() const { return Config; }

	/**
	 * Looks up the session for ModelId. Intended to be called from the render
	 * thread.
	 *
	 * WARNING: the returned raw pointer is used after SessionLock has been
	 * released. A concurrent UnloadModel/Deinitialize on the game thread can
	 * free it mid-inference. Callers must guarantee that no unload happens while
	 * the session is in use, for example by flushing rendering commands before
	 * unloading.
	 */
	OrtSession* GetSession(FName ModelId) const;

	/** Shared ONNX Runtime environment (nullptr before initialization / after shutdown). */
	OrtEnv* GetEnvironment() const { return OrtEnvironment; }

	/** Shared ORT memory-info handle used to describe tensor buffers. */
	OrtMemoryInfo* GetMemoryInfo() const { return MemoryInfo; }

private:
	void InitializeONNXRuntime();
	void ShutdownONNXRuntime();

	FStyleTransferConfig Config;

	// ONNX Runtime handles (created once, reused for every session).
	OrtEnv* OrtEnvironment = nullptr;
	OrtMemoryInfo* MemoryInfo = nullptr;

	// Model sessions keyed by model id; guarded by SessionLock.
	TMap<FName, OrtSession*> LoadedSessions;

	// Guards LoadedSessions. Mutable so const lookups can lock it.
	mutable FCriticalSection SessionLock;
};

// StyleTransferViewExtension.h
#pragma once

#include <atomic>

#include "SceneViewExtension.h"
#include "RenderGraphBuilder.h"
#include "ScreenPass.h"

class UStyleTransferSubsystem;
class UWorld;

/**
 * Scene view extension that injects the style-transfer passes into the render
 * graph just before post-processing.
 *
 * It is configured from the game thread (SetEnabled / SetSubsystem) and
 * consumed on the render thread (PrePostProcessPass_RenderThread). For that
 * reason the enable flag is an atomic.
 */
class FStyleTransferViewExtension : public FSceneViewExtensionBase
{
public:
	FStyleTransferViewExtension(const FAutoRegister& AutoRegister, UWorld* InWorld);
	virtual ~FStyleTransferViewExtension();

	//~ FSceneViewExtensionBase interface (unused hooks are intentionally no-ops)
	virtual void SetupViewFamily(FSceneViewFamily& InViewFamily) override {}
	virtual void SetupView(FSceneViewFamily& InViewFamily, FSceneView& InView) override {}
	virtual void BeginRenderViewFamily(FSceneViewFamily& InViewFamily) override {}

	// Render-thread injection point: adds the style-transfer RDG passes after
	// lighting/translucency and before the post-process chain.
	virtual void PrePostProcessPass_RenderThread(
		FRDGBuilder& GraphBuilder,
		const FSceneView& View,
		const FPostProcessingInputs& Inputs) override;

	// Decides whether this extension participates for the given context
	// (presumably filters on World and bIsEnabled — see the .cpp).
	virtual bool IsActiveThisFrame_Internal(
		const FSceneViewExtensionContext& Context) const override;

	//~ Configuration (game thread)
	void SetEnabled(bool bEnabled) { bIsEnabled.store(bEnabled); }
	void SetSubsystem(UStyleTransferSubsystem* InSubsystem) { Subsystem = InSubsystem; }

private:
	// World this extension was registered for.
	TWeakObjectPtr<UWorld> World;

	// NOTE(review): this is a raw UObject pointer held by a non-UObject, so it is
	// not visible to GC. The subsystem must outlive every render command that
	// dereferences it, or the extension must be disabled/unregistered (and
	// rendering flushed) before the subsystem deinitializes.
	UStyleTransferSubsystem* Subsystem = nullptr;

	// Written on the game thread and read on the render thread, hence atomic.
	std::atomic<bool> bIsEnabled{true};
};

// StyleTransferPass.h
#pragma once

#include "RenderGraphBuilder.h"
#include "ScreenPass.h"
// NOTE(review): PostProcess/PostProcessing.h lives in the Renderer module's
// Private folder and is not normally reachable from a game module. Confirm the
// include path, or use the public header that declares FPostProcessingInputs
// in your engine version.
#include "PostProcess/PostProcessing.h"

// Used by pointer only; a forward declaration avoids a dependency on the
// subsystem header.
class UStyleTransferSubsystem;

// Shader parameters for the style-transfer pass.
BEGIN_SHADER_PARAMETER_STRUCT(FStyleTransferPassParameters, )
	// Input scene color (from the lighting pass).
	SHADER_PARAMETER_RDG_TEXTURE(Texture2D, InputSceneColor)
	// Optional: depth for edge-aware compositing (may be left null when unused).
	SHADER_PARAMETER_RDG_TEXTURE(Texture2D, InputSceneDepth)
	// Optional: stencil/mask for per-object stylization (may be left null when unused).
	SHADER_PARAMETER_RDG_TEXTURE(Texture2D, InputObjectMask)
	// Render-target slot(s) that receive the stylized output.
	RENDER_TARGET_BINDING_SLOTS()
END_SHADER_PARAMETER_STRUCT()

// Intermediate data passed between the extraction, inference and composite
// steps. Texture refs default to nullptr so an unfilled step is detectable
// instead of reading garbage.
struct FStyleTransferIntermediates
{
	// Downsampled network input (TargetSize passed to ExtractAndDownsample).
	FRDGTextureRef DownsampledInput = nullptr;

	// Network output; may hold a previous frame's result when inference is async.
	FRDGTextureRef StylizedOutput = nullptr;

	// Region of interest in original (viewport) resolution.
	FIntRect ROI;

	// Blend factor applied when compositing (1 = fully stylized).
	float StyleStrength = 1.0f;
};

// Orchestrates style transfer for one view: extract ROI -> infer -> composite.
class FStyleTransferPass
{
public:
	FStyleTransferPass(UStyleTransferSubsystem* InSubsystem);

	// Adds all style-transfer passes to the render graph. InOutSceneColor is
	// replaced with the composited result.
	void AddPasses(
		FRDGBuilder& GraphBuilder,
		const FSceneView& View,
		const FPostProcessingInputs& Inputs,
		FScreenPassTexture& InOutSceneColor);

private:
	// Step 1: extract the ROI from SceneColor and downsample it to TargetSize.
	FStyleTransferIntermediates ExtractAndDownsample(
		FRDGBuilder& GraphBuilder,
		FScreenPassTexture& SceneColor,
		const FIntRect& ROI,
		FIntPoint TargetSize);

	// Step 2: run ONNX inference (async; may use the previous frame's output).
	void RunInference(
		FRDGBuilder& GraphBuilder,
		FStyleTransferIntermediates& Intermediates);

	// Step 3: upsample the stylized ROI and blend it back into scene color.
	void UpsampleAndComposite(
		FRDGBuilder& GraphBuilder,
		const FStyleTransferIntermediates& Intermediates,
		FScreenPassTexture& InOutSceneColor);

	// Non-owning; must outlive any render work that uses it.
	UStyleTransferSubsystem* Subsystem = nullptr;
};

// ONNXRuntimeWrapper.h
#pragma once

#include "CoreMinimal.h"

// ONNX Runtime C API
#include "onnxruntime_c_api.h"

/**
 * Low-level wrapper around the ONNX Runtime C API.
 * Handles session creation, tensor I/O and inference execution.
 *
 * Owns raw ORT handles (environment, memory info, session options, sessions),
 * so it is non-copyable and non-movable. Copying it would release the same
 * handles twice.
 * Sessions is guarded by Lock.
 */
class FONNXRuntimeWrapper
{
public:
	FONNXRuntimeWrapper();
	~FONNXRuntimeWrapper();

	FONNXRuntimeWrapper(const FONNXRuntimeWrapper&) = delete;
	FONNXRuntimeWrapper& operator=(const FONNXRuntimeWrapper&) = delete;
	FONNXRuntimeWrapper(FONNXRuntimeWrapper&&) = delete;
	FONNXRuntimeWrapper& operator=(FONNXRuntimeWrapper&&) = delete;

	// Initializes the ONNX Runtime environment; optionally enables the DirectML EP.
	bool Initialize(bool bUseDirectML = true);

	// Shuts down and releases all resources, including every loaded session.
	void Shutdown();

	// Loads a model and creates a session registered under SessionId.
	bool LoadModel(const FString& ModelPath, FName SessionId);

	// Releases the session registered under SessionId, if any.
	void UnloadModel(FName SessionId);

	// Synchronous inference (blocks until complete).
	// InputData:  NHWC float32 tensor, normalized [0,1]
	// OutputData: NHWC float32 tensor, normalized [0,1]
	// NOTE(review): the ONNX model-zoo fast-neural-style models (mosaic-9,
	// candy-9, udnie-9) take NCHW input with pixel values in [0,255] and emit
	// unclamped NCHW output in the same range. If this contract is kept, the
	// implementation must transpose and rescale on the way in and out — confirm.
	bool RunInference(
		FName SessionId,
		const TArray<float>& InputData,
		int32 BatchSize,
		int32 Height,
		int32 Width,
		int32 Channels,
		TArray<float>& OutputData);

	// Reports the model's input spatial size and channel count.
	bool GetModelDimensions(
		FName SessionId,
		int32& OutHeight,
		int32& OutWidth,
		int32& OutChannels) const;

	// Checks whether DirectML is available.
	static bool IsDirectMLAvailable();

private:
	const OrtApi* Api = nullptr;
	OrtEnv* Environment = nullptr;
	OrtMemoryInfo* MemoryInfo = nullptr;
	OrtSessionOptions* SessionOptions = nullptr;

	// Sessions keyed by id; guarded by Lock.
	TMap<FName, OrtSession*> Sessions;

	// Mutable so const members such as GetModelDimensions can lock it while
	// reading Sessions.
	mutable FCriticalSection Lock;

	// Logs a failed OrtStatus with Context. Presumably also releases Status via
	// Api->ReleaseStatus — confirm in the .cpp, since an OrtStatus must be freed.
	void LogOrtError(OrtStatus* Status, const TCHAR* Context);
};

// StylizedObjectComponent.h
#pragma once
#include "CoreMinimal.h"
#include "Components/ActorComponent.h"
#include "StylizedObjectComponent.generated.h"
// Marks an actor for neural style transfer and carries its per-object style
// settings (model id, strength, refresh rate) plus style-transition state.
UCLASS(ClassGroup=(Rendering), meta=(BlueprintSpawnableComponent))
class YOURGAME_API UStylizedObjectComponent : public UActorComponent
{
GENERATED_BODY()
public:
UStylizedObjectComponent();
// Style settings
// Id of the style model to apply; presumably must match a ModelId passed to
// UStyleTransferSubsystem::LoadModel — confirm.
UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer")
FName StyleModelId;
// Blend factor for the stylized result (0 = original, 1 = fully stylized).
// NOTE(review): ClampMin/ClampMax constrain editor input only; Blueprint/C++
// writes are presumably unclamped, so clamp where the value is consumed.
UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer", meta = (ClampMin = "0.0", ClampMax = "1.0"))
float StyleStrength = 1.0f;
// Presumably the stylization is refreshed only every N frames (tracked with
// FrameCounter) to save inference cost — confirm in the .cpp.
UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer", meta = (ClampMin = "1", ClampMax = "60"))
int32 UpdateEveryNFrames = 1;
// Per-object on/off switch.
UPROPERTY(EditAnywhere, BlueprintReadWrite, Category = "Style Transfer")
bool bEnabled = true;
// Runtime state
// World-space bounds of the owning actor; per the design, used to compute the
// screen-space ROI for extraction.
UFUNCTION(BlueprintCallable, Category = "Style Transfer")
FBox GetWorldBounds() const;
// Starts a transition toward NewStyleId lasting TransitionDuration seconds
// (presumably by resetting TransitionAlpha and deriving TransitionSpeed).
UFUNCTION(BlueprintCallable, Category = "Style Transfer")
void SetTargetStyle(FName NewStyleId, float TransitionDuration = 0.5f);
// Called by the render system to get the current interpolated style.
// NOTE(review): if these are called on the render thread they read game-thread
// state without synchronization; prefer copying the values into a render-thread
// snapshot.
float GetCurrentStyleStrength() const;
FName GetCurrentStyleId() const;
protected:
virtual void BeginPlay() override;
// Presumably advances TransitionAlpha and FrameCounter each tick.
virtual void TickComponent(float DeltaTime, ELevelTick TickType, FActorComponentTickFunction* ThisTickFunction) override;
private:
// For smooth transitions
// Style being blended away from during a transition.
FName PreviousStyleId;
// Style being blended toward.
FName TargetStyleId;
// Transition progress; the 1.0 default presumably means "no transition running".
float TransitionAlpha = 1.0f;
// Alpha change per second; the 2.0 default equals 1 / the default 0.5 s
// TransitionDuration of SetTargetStyle.
float TransitionSpeed = 2.0f;
// Frames counted toward UpdateEveryNFrames.
int32 FrameCounter = 0;
};FRAME N DATA FLOW
═════════════════
Game Thread Render Thread GPU
───────────────────────────────────────────────────────────────────────────────
1. Tick Components
┌─────────────────────┐
│ UStylizedObject │
│ Component::Tick() │
│ │
│ • Update style lerp │
│ • Check frame skip │
└──────────┬──────────┘
│
│ (via Scene Proxy / View Extension)
│
▼
2. PrePostProcessPass_RenderThread
┌─────────────────────────┐
│ FStyleTransferView │
│ Extension::PrePost...() │
│ │
│ • Query visible styled │
│ objects │
│ • Compute ROIs │
└───────────┬─────────────┘
│
▼
3. Add RDG Passes
┌─────────────────────────┐
│ FStyleTransferPass:: │
│ AddPasses() │
│ │
│ a) Downsample Pass │──────▶ [GPU: Bilinear downsample]
│ (SceneColor → 256²) │
│ │
│ b) Copy to Staging │──────▶ [GPU→CPU readback]
│ (RDG → CPU buffer) │ (or GPU-to-GPU via DML)
│ │
│ c) ONNX Inference │──────▶ [GPU: DirectML EP]
│ (via wrapper) │ ┌─────────────────┐
│ │ │ CNN Forward Pass│
│ │ │ ~1.7M params │
│ │ │ ~8ms @ 256² │
│ │ └─────────────────┘
│ d) Copy from Staging │◀────── [CPU→GPU upload]
│ (CPU → RDG texture) │ (or GPU-to-GPU)
│ │
│ e) Upsample + Composite │──────▶ [GPU: Bilinear + blend]
│ (Stylized → Scene) │
└─────────────────────────┘
│
▼
4. Continue Pipeline
┌─────────────────────────┐
│ Tonemapping, UI, etc. │──────▶ [GPU: Post-process chain]
└─────────────────────────┘TEXTURE FORMATS & SIZES
═══════════════════════
┌─────────────────────────────────────────────────────────────────────────────┐
│ Resource │ Format │ Size │ Usage │
├───────────────────────┼──────────────────┼─────────────┼────────────────────┤
│ SceneColor (Input) │ PF_FloatRGBA │ Viewport │ SRV (read) │
│ DownsampledInput │ PF_FloatRGBA │ 256×256 │ SRV + UAV │
│ ONNXInputStaging │ PF_A32B32G32R32F │ 256×256 │ CPU-visible │
│ ONNXOutputStaging │ PF_A32B32G32R32F │ 256×256 │ CPU-visible │
│ StylizedOutput │ PF_FloatRGBA │ 256×256 │ SRV │
│ UpsampledStylized │ PF_FloatRGBA │ ROI size │ SRV │
│ FinalComposite │ PF_FloatRGBA │ Viewport │ RTV (write) │
└───────────────────────┴──────────────────┴─────────────┴────────────────────┘
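MEMORY ESTIMATE (per styled object, 256×256 pipeline):
Input staging: 256 × 256 × 4 × 4 bytes = 1 MB
Output staging: 256 × 256 × 4 × 4 bytes = 1 MB
Intermediates: ~2 MB
─────────────────────────────────────────────
Total per object: ~4 MB
Total (3 objects): ~12 MB

SHARED (per loaded model, not per object):
ONNX model: ~6.7 MB (mosaic-9.onnx: ~1.7M float32 params × 4 bytes; shared by every object using that style)
─────────────────────────────────────────────
Total (3 objects, 1 model): ~19 MB + ORT/DirectML workspace overhead

THREAD RESPONSIBILITIES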
═══════════════════════
┌─────────────────────────────────────────────────────────────────────────────┐
│ │
│ GAME THREAD │
│ ─────────────────────────────────────────────────────────────────────── │
│ • UStyleTransferSubsystem lifecycle │
│ • Model loading/unloading (async task) │
│ • Component tick (style parameter updates) │
│ • Blueprint API │
│ │
│ RENDER THREAD │
│ ─────────────────────────────────────────────────────────────────────── │
│ • View extension render-thread callbacks (PrePostProcessPass_RenderThread) │
│ • RDG pass registration │
│ • ROI computation │
│ • Shader dispatch │
│ │
│ RHI THREAD │
│ ─────────────────────────────────────────────────────────────────────── │
│ • Actual GPU command submission │
│ • Resource transitions │
│ • Staging buffer copies │
│ │
│ ASYNC INFERENCE (optional, for latency hiding) │
│ ─────────────────────────────────────────────────────────────────────── │
│ • ONNX Runtime inference on dedicated thread │
│ • Double-buffered input/output │
│ • 1-frame latency tradeoff for throughput │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
SYNCHRONIZATION POINTS:
[Game Thread] [Render Thread]
│ │
│◀──── FCriticalSection ────────▶│ (session access)
│ │
│ │
│ FRenderCommandFence │
│◀───────────────────────────────│ (readback sync)
│ │

BUILD DEPENDENCIES (YourGame.Build.cs)
══════════════════════════════════════
PublicDependencyModuleNames:
├── Core
├── CoreUObject
├── Engine
├── RenderCore
├── Renderer
├── RHI
└── Projects
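PrivateDependencyModuleNames:
└── ONNXRuntime // External: third-party module wrapping onnxruntime.dll
Note: there is no separate "RenderGraph" module. FRDGBuilder/FRDGTexture are
declared in RenderCore, and FSceneViewExtensionBase is declared in Engine
(SceneViewExtension.h). Both modules are already listed above, as is Renderer
(needed for ScreenPass / post-processing types), so none of them is repeated here.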
Third Party:
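└── onnxruntime (1.16+, DirectML build)
├── onnxruntime.lib
├── onnxruntime.dll (the DirectML EP is compiled into this DLL in the DirectML build)
├── onnxruntime_providers_shared.dll
└── DirectML.dll (DirectML runtime required by the DML EP; there is no separate onnxruntime_providers_dml.dll)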
HEADER INCLUDES:
• "onnxruntime_c_api.h"
• "dml_provider_factory.h" (for DirectML EP registration)

Source/YourGame/
├── YourGame.Build.cs # Add ONNX Runtime dependency
│
├── StyleTransfer/
│ ├── Public/
│ │ ├── StyleTransferSubsystem.h # Game instance subsystem
│ │ ├── StyleTransferViewExtension.h # Render thread hook
│ │ ├── StyleTransferPass.h # RDG pass definitions
│ │ ├── StylizedObjectComponent.h # Actor component
│ │ ├── StyleZoneTrigger.h # Zone-based style switching
│ │ └── ONNXRuntimeWrapper.h # C API wrapper
│ │
│ └── Private/
│ ├── StyleTransferSubsystem.cpp
│ ├── StyleTransferViewExtension.cpp
│ ├── StyleTransferPass.cpp
│ ├── StylizedObjectComponent.cpp
│ ├── StyleZoneTrigger.cpp
│ └── ONNXRuntimeWrapper.cpp
│
├── Shaders/ # (Optional: custom HLSL)
│ ├── StyleTransferDownsample.usf
│ └── StyleTransferComposite.usf
│
Content/
├── StyleTransfer/
│ ├── Models/
│ │ ├── mosaic-9.onnx
│ │ ├── candy-9.onnx
│ │ └── udnie-9.onnx
│ │
│ └── Demo/
│ └── SculptureHall.umap
│
Binaries/Win64/
├── onnxruntime.dll
├── onnxruntime_providers_shared.dll
└── DirectML.dll

| Risk | Likelihood | Impact | Mitigation |
|---|---|---|---|
| DirectML interop with UE5 RHI | Medium | High | Start with CPU inference, add DML later |
| Staging buffer copy latency | High | Medium | Accept 1-frame latency, async inference |
| ONNX Runtime DLL loading | Low | High | Use delay-load, explicit LoadLibrary |
| RDG resource lifetime | Medium | Medium | Careful pass ordering, ExtractTexture |
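| Thread safety (session access) | Medium | High | FCriticalSection around the session map and model load/unload. Also ensure a session cannot be released while an inference is in flight; ORT's Run is itself thread-safe, so the lifetime hazard matters more than serializing every call |
| HDR scene color vs. LDR-trained model | High | Medium | SceneColor before post-processing is linear HDR, while the model-zoo style models expect 8-bit-range RGB. Tonemap/encode before inference and convert back before compositing |
| .onnx files not packaged | Medium | High | Non-.uasset files under Content/ are not cooked by default. Add Content/StyleTransfer/Models to "Additional Non-Asset Directories to Package" or ship them as RuntimeDependencies |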